小熊科技 | PyTorch Text: 01. Chatbot Tutorial (11)


Finally, if the input sentence contains a word that is not in the vocabulary, we handle this gracefully by printing an error message and prompting the user to enter another sentence.
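For reference, a minimal sketch of what that graceful handling can look like, assuming the evaluate and normalizeString functions and the encoder, decoder, searcher, and voc objects built earlier in this series; the except branch catches the KeyError raised when a word is missing from the vocabulary:

def evaluateInput(encoder, decoder, searcher, voc):
    while True:
        try:
            # Get an input sentence (type 'q' or 'quit' to exit)
            input_sentence = input('> ')
            if input_sentence == 'q' or input_sentence == 'quit':
                break
            # Normalize the sentence and generate a response
            input_sentence = normalizeString(input_sentence)
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Drop EOS/PAD tokens before printing the reply
            output_words[:] = [w for w in output_words if w not in ('EOS', 'PAD')]
            print('Bot:', ' '.join(output_words))
        except KeyError:
            # Out-of-vocabulary word: report it and ask for another sentence
            print("Error: Encountered unknown word.")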
outputs: output features from the last hidden layer of the GRU (sum of the bidirectional outputs); shape = (max_length, batch_size, hidden_size)
hidden: the updated hidden state from the GRU; shape = (n_layers x num_directions, batch_size, hidden_size)

class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size'
        # because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack the padded batch of sequences for the RNN module
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through the GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum the bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        # Return the output and final hidden state
        return outputs, hidden

7. Run the Model

Finally, it's time to run our model!
Regardless of whether we want to train or test the chatbot model, we must initialize the individual encoder and decoder models. In the following block, we set our desired configurations, choose to start from scratch or set a checkpoint to load from, and build and initialize the models. Feel free to play with different configurations to optimize performance.
# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
checkpoint_iter = 4000
#loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                            '{}_checkpoint.tar'.format(checkpoint_iter))

# Load the model if a loadFilename is provided
if loadFilename:
    # If loading on the same machine the model was trained on
    checkpoint = torch.load(loadFilename)
    # If loading a model trained on GPU to CPU
    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']

print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder

Training output:

Iteration: 1; Percent complete: 0.0%; Average loss: 8.9717
Iteration: 2; Percent complete: 0.1%; Average loss: 8.8521
Iteration: 3; Percent complete: 0.1%; Average loss: 8.6360
Iteration: 4; Percent complete: 0.1%; Average loss: 8.4234
Iteration: 5; Percent complete: 0.1%; Average loss: 7.9403
Iteration: 6; Percent complete: 0.1%; Average loss: 7.3892
Iteration: 7; Percent complete: 0.2%; Average loss: 7.0589
Iteration: 8; Percent complete: 0.2%; Average loss: 7.0130
Iteration: 9; Percent complete: 0.2%; Average loss: 6.7383
Iteration: 10; Percent complete: 0.2%; Average loss: 6.5343
...
Iteration: 3991; Percent complete: 99.8%; Average loss: 2.6607
Iteration: 3992; Percent complete: 99.8%; Average loss: 2.6188
Iteration: 3993; Percent complete: 99.8%; Average loss: 2.8319
Iteration: 3994; Percent complete: 99.9%; Average loss: 2.5817
Iteration: 3995; Percent complete: 99.9%; Average loss: 2.4979
Iteration: 3996; Percent complete: 99.9%; Average loss: 2.7317
Iteration: 3997; Percent complete: 99.9%; Average loss: 2.5969
Iteration: 3998; Percent complete: 100.0%; Average loss: 2.2275
Iteration: 3999; Percent complete: 100.0%; Average loss: 2.7124
Iteration: 4000; Percent complete: 100.0%; Average loss: 2.5975
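The configuration listing above stops at the "# Initialize encoder" comment; the log that follows is the progress printed during training. A minimal sketch of the construction step that would follow that comment, assuming the EncoderRNN defined above plus the LuongAttnDecoderRNN class and device variable from earlier parts of this series:

# Initialize encoder & decoder models (sketch; LuongAttnDecoderRNN and device
# are assumed to be defined in earlier parts of this tutorial series)
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Move the models to the appropriate device (CPU or GPU)
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')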

