小熊科技|PyTorch文本:01.聊天机器人教程( 三 )

2.2 加载和清洗数据我们下一个任务是创建词汇表并将查询/响应句子对(对话)加载到内存 。
注意我们正在处理词序 , 这些词序没有映射到离散数值空间 。 因此 , 我们必须通过数据集中的单词来创建一个索引 。
为此我们创建了一个Voc类,它会存储从单词到索引的映射、索引到单词的反向映射、每个单词的计数和总单词量 。 这个类提供向词汇表中添 加单词的方法(addWord)、添加所有单词到句子中的方法 (addSentence) 和清洗不常见的单词方法(trim) 。 更多的数据清洗在后面进行 。
# 默认词向量PAD_token = 0# Used for padding short sentencesSOS_token = 1# Start-of-sentence tokenEOS_token = 2# End-of-sentence tokenclass Voc:def __init__(self, name):self.name = nameself.trimmed = Falseself.word2index = {}self.word2count = {}self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}self.num_words = 3# Count SOS, EOS, PADdef addSentence(self, sentence):for word in sentence.split(' '):self.addWord(word)def addWord(self, word):if word not in self.word2index:self.word2index[word] = self.num_wordsself.word2count[word] = 1self.index2word[self.num_words] = wordself.num_words += 1else:self.word2count[word] += 1# 删除低于特定计数阈值的单词def trim(self, min_count):if self.trimmed:returnself.trimmed = Truekeep_words = []for k, v in self.word2count.items():if v >= min_count:keep_words.append(k)print('keep_words {} / {} = {:.4f}'.format(len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)))# 重初始化字典self.word2index = {}self.word2count = {}self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}self.num_words = 3 # Count default tokensfor word in keep_words:self.addWord(word)现在我们可以组装词汇表和查询/响应语句对 。 在使用数据之前 , 我们必须做一些预处理 。
首先 , 我们必须使用unicodeToAscii将 unicode 字符串转换为 ASCII 。 然后 , 我们应该将所有字母转换为小写字母并清洗掉除基本标点之 外的所有非字母字符 (normalizeString) 。 最后 , 为了帮助训练收敛 , 我们将过滤掉长度大于MAX_LENGTH 的句子 (filterPairs) 。
MAX_LENGTH = 10# Maximum sentence length to consider# 将Unicode字符串转换为纯ASCII , 多亏了# def unicodeToAscii(s):return ''.join(c for c in unicodedata.normalize('NFD', s)if unicodedata.category(c) != 'Mn')# 初始化Voc对象 和 格式化pairs对话存放到list中def readVocs(datafile, corpus_name):print("Reading lines...")# Read the file and split into lineslines = open(datafile, encoding='utf-8').read().strip().split('\n')# Split every line into pairs and normalizepairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]voc = Voc(corpus_name)return voc, pairs# 如果对 'p' 中的两个句子都低于 MAX_LENGTH 阈值 , 则返回Truedef filterPair(p):# Input sequences need to preserve the last word for EOS tokenreturn len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH# 过滤满足条件的 pairs 对话def filterPairs(pairs):return [pair for pair in pairs if filterPair(pair)]# 使用上面定义的函数 , 返回一个填充的voc对象和对列表def loadPrepareData(corpus, corpus_name, datafile, save_dir):print("Start preparing training data ...")voc, pairs = readVocs(datafile, corpus_name)print("Read {!s} sentence pairs".format(len(pairs)))pairs = filterPairs(pairs)print("Trimmed to {!s} sentence pairs".format(len(pairs)))print("Counting words...")for pair in pairs:voc.addSentence(pair[0])voc.addSentence(pair[1])print("Counted words:", voc.num_words)return voc, pairs# 加载/组装voc和对save_dir = os.path.join("data", "save")voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)# 打印一些对进行验证print("\npairs:")for pair in pairs[:10]:print(pair)


推荐阅读