系列文章
【如何訓(xùn)練一個中英翻譯模型】LSTM機器翻譯seq2seq字符編碼(一)
【如何訓(xùn)練一個中英翻譯模型】LSTM機器翻譯模型訓(xùn)練與保存(二)
【如何訓(xùn)練一個中英翻譯模型】LSTM機器翻譯模型部署(三)
【如何訓(xùn)練一個中英翻譯模型】LSTM機器翻譯模型部署之onnx(python)(四)
模型部署也是很重要的一部分,這里先講基于python的部署,后面我們還要將模型部署到移動端。
細心的小伙伴會發(fā)現(xiàn)前面的文章在模型保存之后進行模型推理時,我們使用的數(shù)據(jù)是在訓(xùn)練之前我們對數(shù)據(jù)進行處理的encoder_input_data中讀取,而不是我們手動輸入的,那么這一章主要來解決自定義輸入推理的問題
1、加載字符文件
首先,我們根據(jù) 【如何訓(xùn)練一個中譯英翻譯器】LSTM機器翻譯模型訓(xùn)練與保存(二)的操作,到最后
會得到這樣的三個文件:input_words.txt,target_words.txt,config.json
需要逐一進行加載
進行加載
# 加載字符
# 從 input_words.txt 文件中讀取字符串
with open('input_words.txt', 'r') as f:
input_words = f.readlines()
input_characters = [line.rstrip('\n') for line in input_words]
# 從 target_words.txt 文件中讀取字符串
with open('target_words.txt', 'r', newline='') as f:
target_words = [line.strip() for line in f.readlines()]
target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]
#字符處理,以方便進行編碼
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])
# something readable.
reverse_input_char_index = dict(
(i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
(i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符數(shù)量
num_decoder_tokens = len(target_characters) # 中文文字數(shù)量
讀取配置文件
import json
with open('config.json', 'r') as file:
loaded_data = json.load(file)
# 從加載的數(shù)據(jù)中獲取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]
2、加載權(quán)重文件
# 加載權(quán)重
from keras.models import load_model
encoder_model = load_model('encoder_model.h5')
decoder_model = load_model('decoder_model.h5')
3、推理模型搭建
def decode_sequence(input_seq):
# Encode the input as state vectors.
states_value = encoder_model.predict(input_seq)
# Generate empty target sequence of length 1.
target_seq = np.zeros((1, 1, num_decoder_tokens))
# Populate the first character of target sequence with the start character.
target_seq[0, 0, target_token_index['\t']] = 1.
# this target_seq you can treat as initial state
# Sampling loop for a batch of sequences
# (to simplify, here we assume a batch of size 1).
stop_condition = False
decoded_sentence = ''
while not stop_condition:
output_tokens, h, c = decoder_model.predict([target_seq] + states_value)
# Sample a token
# argmax: Returns the indices of the maximum values along an axis
# just like find the most possible char
sampled_token_index = np.argmax(output_tokens[0, -1, :])
# find char using index
sampled_char = reverse_target_char_index[sampled_token_index]
# and append sentence
decoded_sentence += sampled_char
# Exit condition: either hit max length
# or find stop character.
if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):
stop_condition = True
# Update the target sequence (of length 1).
# append then ?
# creating another new target_seq
# and this time assume sampled_token_index to 1.0
target_seq = np.zeros((1, 1, num_decoder_tokens))
target_seq[0, 0, sampled_token_index] = 1.
# Update states
# update states, frome the front parts
states_value = [h, c]
return decoded_sentence
4、進行推理
import numpy as np
input_text = "Call me."
encoder_input_data = np.zeros(
(1,max_encoder_seq_length, num_encoder_tokens),
dtype='float32')
for t, char in enumerate(input_text):
print(char)
# 3D vector only z-index has char its value equals 1.0
encoder_input_data[0,t, input_token_index[char]] = 1.
input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)
運行結(jié)果:文章來源:http://www.zghlxwxcb.cn/news/detail-606758.html
以上的代碼可在kaggle上運行:how-to-train-a-chinese-to-english-translator-iii文章來源地址http://www.zghlxwxcb.cn/news/detail-606758.html
到了這里,關(guān)于【如何訓(xùn)練一個中英翻譯模型】LSTM機器翻譯模型部署(三)的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!