神经网络中的蒸馏技术,从Softmax开始说起( 二 )


扩展Softmax这些弱概率的问题是,它们没有捕捉到学生模型有效学习所需的信息 。例如,如果概率分布像[0.99, 0.01],几乎不可能传递图像具有数字7的特征的知识 。
Hinton等人解决这个问题的方法是,在将原始logits传递给softmax之前,将教师模型的原始logits按一定的温度进行缩放 。这样,就会在可用的类标签中得到更广泛的分布 。然后用同样的温度用于训练学生模型 。
我们可以把学生模型的修正损失函数写成这个方程的形式:

神经网络中的蒸馏技术,从Softmax开始说起

文章插图
 
其中,pi是教师模型得到软概率分布,si的表达式为:
神经网络中的蒸馏技术,从Softmax开始说起

文章插图
 
def get_kd_loss(student_logits, teacher_logits,                 true_labels, temperature,                alpha, beta):        teacher_probs = tf.nn.softmax(teacher_logits / temperature)    kd_loss = tf.keras.losses.categorical_crossentropy(        teacher_probs, student_logits / temperature,         from_logits=True)        return kd_loss使用扩展Softmax来合并硬标签Hinton等人还探索了在真实标签(通常是独热编码)和学生模型的预测之间使用传统交叉熵损失的想法 。当训练数据集很小,并且软标签没有足够的信号供学生模型采集时,这一点尤其有用 。
当它与扩展的softmax相结合时,这种方法的工作效果明显更好,而整体损失函数成为两者之间的加权平均 。
神经网络中的蒸馏技术,从Softmax开始说起

文章插图
 
def get_kd_loss(student_logits, teacher_logits,                 true_labels, temperature,                alpha, beta):    teacher_probs = tf.nn.softmax(teacher_logits / temperature)    kd_loss = tf.keras.losses.categorical_crossentropy(        teacher_probs, student_logits / temperature,         from_logits=True)        ce_loss = tf.keras.losses.sparse_categorical_crossentropy(        true_labels, student_logits, from_logits=True)        total_loss = (alpha * kd_loss) + (beta * ce_loss)    return total_loss / (alpha + beta)建议β的权重小于α 。
在原始Logits上进行操作Caruana等人操作原始logits,而不是softmax值 。这个工作流程如下:
  • 这部分保持相同 —— 训练一个教师模型 。这里交叉熵损失将根据数据集中的真实标签计算 。
  • 现在,为了训练学生模型,训练目标变成分别最小化来自教师和学生模型的原始对数之间的平均平方误差 。

神经网络中的蒸馏技术,从Softmax开始说起

文章插图
 
mse = tf.keras.losses.MeanSquaredError()def mse_kd_loss(teacher_logits, student_logits):    return mse(teacher_logits, student_logits)使用这个损失函数的一个潜在缺点是它是无界的 。原始logits可以捕获噪声,而一个小模型可能无法很好的拟合 。这就是为什么为了使这个损失函数很好地适合蒸馏状态,学生模型需要更大一点 。
Tang等人探索了在两个损失之间插值的想法:扩展softmax和MSE损失 。数学上,它看起来是这样的:
神经网络中的蒸馏技术,从Softmax开始说起

文章插图
 
根据经验,他们发现当α = 0时,(在NLP任务上)可以获得最佳的性能 。
如果你在这一点上感到有点不知怎么办,不要担心 。希望通过代码,事情会变得清楚 。
一些训练方法在本节中,我将向你提供一些在使用知识蒸馏时可以考虑的训练方法 。
使用数据增强他们在NLP数据集上展示了这个想法,但这也适用于其他领域 。为了更好地指导学生模型训练,使用数据增强会有帮助,特别是当你处理的数据较少的时候 。因为我们通常保持学生模型比教师模型小得多,所以我们希望学生模型能够获得更多不同的数据,从而更好地捕捉领域知识 。


推荐阅读