小熊科技|PyTorch文本:01.聊天机器人教程( 十 )

6.评估定义在训练模型后 , 我们希望能够自己与机器人交谈 。 首先 , 我们必须定义我们希望模型如何解码编码输入 。
6.1 贪婪解码贪婪解码是我们在不使用 teacher forcing 时在训练期间使用的解码方法 。 换句话说 , 对于每一步 , 我们只需从具有最高 softmax 值的 decoder_output 中选择单词 。 该解码方法在单步长级别上是最佳的 。
为了便于贪婪解码操作 , 我们定义了一个GreedySearchDecoder类 。 当运行时 , 类的实例化对象输入序列(input_seq)的大小是(input_seq length , 1) ,标量输入(input_length)长度的张量和 max_length 来约束响应句子长度 。 使用以下计算图来评估输入句子:
计算图
1.通过编码器模型前向计算 。
2.准备编码器的最终隐藏层 , 作为解码器的第一个隐藏输入 。
3.将解码器的第一个输入初始化为 SOS_token 。
4.将初始化张量追加到解码后的单词中 。
5.一次迭代解码一个单词token:
?(i)通过解码器进行前向计算 。
?(ii)获得最可能的单词token及其softmax分数 。
?(iii)记录token和分数 。
?(iv)准备当前token作为下一个解码器的输入 。
6.返回收集到的词 tokens 和分数 。
class GreedySearchDecoder(nn.Module):def __init__(self, encoder, decoder):super(GreedySearchDecoder, self).__init__()self.encoder = encoderself.decoder = decoderdef forward(self, input_seq, input_length, max_length):# 通过编码器模型转发输入encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)# 准备编码器的最终隐藏层作为解码器的第一个隐藏输入decoder_hidden = encoder_hidden[:decoder.n_layers]# 使用SOS_token初始化解码器输入decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token# 初始化张量以将解码后的单词附加到all_tokens = torch.zeros([0], device=device, dtype=torch.long)all_scores = torch.zeros([0], device=device)# 一次迭代地解码一个词tokensfor _ in range(max_length):# 正向通过解码器decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)# 获得最可能的单词标记及其softmax分数decoder_scores, decoder_input = torch.max(decoder_output, dim=1)# 记录token和分数all_tokens = torch.cat((all_tokens, decoder_input), dim=0)all_scores = torch.cat((all_scores, decoder_scores), dim=0)# 准备当前令牌作为下一个解码器输入(添加维度)decoder_input = torch.unsqueeze(decoder_input, 0)# 返回收集到的词tokens和分数return all_tokens, all_scores6.2 评估我们的文本现在我们已经定义了解码方法 , 我们可以编写用于评估字符串输入句子的函数 。 evaluate函数管理输入句子的低层级处理过程 。 我们首先使 用batch_size == 1将句子格式化为输入batch的单词索引 。我们通过将句子的单词转换为相应的索引 , 并通过转换维度来为我们的模型准备 张量 。 我们还创建了一个lengths张量 , 其中包含输入句子的长度 。 在这种情况下 , lengths是标量 , 因为我们一次只评估一个句子(batch_size == 1) 。接下来 , 我们使用我们的GreedySearchDecoder实例化后的对象(searcher)获得解码响应句子的张量 。 最后 , 我们将响应的索引转换为单 词并返回已解码单词的列表 。
evaluateInput充当聊天机器人的用户接口 。 调用时 , 将生成一个输入文本字段 , 我们可以在其中输入查询语句 。 在输入我们的输入句子并 按 Enter 后 , 我们的文本以与训练数据相同的方式标准化 , 并最终被输入到评估函数以获得解码的输出句子 。 我们循环这个过程 , 这样我们可 以继续与我们的机器人聊天直到我们输入“q”或“quit” 。


推荐阅读