```python
def forward(self, query, user_behavior):
    # query ad      : size -> batch_size * 1 * embedding_size
    # user behavior : size -> batch_size * time_seq_len * embedding_size
    user_behavior_len = user_behavior.size(1)

    queries = query.expand(-1, user_behavior_len, -1)

    # as in the source code, the subtraction simulates the vectors' difference
    attention_input = torch.cat(
        [queries, user_behavior, queries - user_behavior, queries * user_behavior],
        dim=-1)
    attention_output = self.dnn(attention_input)
    attention_score = self.dense(attention_output)  # [B, T, 1]

    return attention_score
```

This snippet is the core of the attention network: it produces the attention weights, which are the most important ingredient. Once we have the weights, all that remains is to multiply them with the embeddings of the user behavior sequence. Thanks to the way matrix multiplication works, multiplying a [B, 1, T] matrix by a [B, T, E] matrix yields a [B, 1, E] result. That product is exactly the weighted sum we need, just implemented as a matrix multiplication.
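To make the shape bookkeeping concrete, here is a minimal sketch with made-up tensors (the sizes B, T, E are arbitrary) confirming that the batched matrix product really is the weighted sum over the behavior sequence:

```python
import torch

# hypothetical sizes, just to check the shapes
B, T, E = 2, 5, 8                      # batch, sequence length, embedding size
weights = torch.rand(B, 1, T)          # attention scores, e.g. transposed from [B, T, 1]
behaviors = torch.rand(B, T, E)        # user behavior embeddings

pooled = torch.matmul(weights, behaviors)    # [B, 1, T] x [B, T, E] -> [B, 1, E]
print(pooled.shape)                          # torch.Size([2, 1, 8])

# the same result computed as an explicit weighted sum over the time axis
manual = (weights.transpose(1, 2) * behaviors).sum(dim=1, keepdim=True)
print(torch.allclose(pooled, manual))        # True
```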
Let's look at the source code again to deepen our understanding:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
# LocalActivationUnit is defined alongside this class in the same source file


class AttentionSequencePoolingLayer(nn.Module):
    """The Attentional sequence pooling operation used in DIN & DIEN.

    Arguments
        - **att_hidden_units**: list of positive integer, the attention net layer number and units in each layer.
        - **att_activation**: Activation function to use in attention net.
        - **weight_normalization**: bool. Whether normalize the attention score of local activation unit.
        - **supports_masking**: If True, the input need to support masking.

    References
        - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
    """

    def __init__(self, att_hidden_units=(80, 40), att_activation='sigmoid', weight_normalization=False,
                 return_score=False, supports_masking=False, embedding_dim=4, **kwargs):
        super(AttentionSequencePoolingLayer, self).__init__()
        self.return_score = return_score
        self.weight_normalization = weight_normalization
        self.supports_masking = supports_masking
        self.local_att = LocalActivationUnit(hidden_units=att_hidden_units, embedding_dim=embedding_dim,
                                             activation=att_activation, dropout_rate=0, use_bn=False)

    def forward(self, query, keys, keys_length, mask=None):
        """
        Input shape
            - A list of three tensor: [query, keys, keys_length]
            - query is a 3D tensor with shape: ``(batch_size, 1, embedding_size)``
            - keys is a 3D tensor with shape: ``(batch_size, T, embedding_size)``
            - keys_length is a 2D tensor with shape: ``(batch_size, 1)``

        Output shape
            - 3D tensor with shape: ``(batch_size, 1, embedding_size)``.
        """
        batch_size, max_length, _ = keys.size()

        # Mask
        if self.supports_masking:
            if mask is None:
                raise ValueError("When supports_masking=True, input must support masking")
            keys_masks = mask.unsqueeze(1)
        else:
            keys_masks = torch.arange(max_length, device=keys_length.device,
                                      dtype=keys_length.dtype).repeat(batch_size, 1)  # [B, T]
            keys_masks = keys_masks < keys_length.view(-1, 1)  # 0, 1 mask
            keys_masks = keys_masks.unsqueeze(1)  # [B, 1, T]

        attention_score = self.local_att(query, keys)      # [B, T, 1]
        outputs = torch.transpose(attention_score, 1, 2)   # [B, 1, T]

        if self.weight_normalization:
            paddings = torch.ones_like(outputs) * (-2 ** 32 + 1)
        else:
            paddings = torch.zeros_like(outputs)

        outputs = torch.where(keys_masks, outputs, paddings)  # [B, 1, T]

        # Scale
        # outputs = outputs / (keys.shape[-1] ** 0.05)

        if self.weight_normalization:
            outputs = F.softmax(outputs, dim=-1)  # [B, 1, T]

        if not self.return_score:
            # Weighted sum
            outputs = torch.matmul(outputs, keys)  # [B, 1, E]

        return outputs
```

This code adds logic for masking, weight normalization, and similar bookkeeping. Ignore all of that and the truly essential backbone is only three lines:
```python
attention_score = self.local_att(query, keys)      # [B, T, 1]
outputs = torch.transpose(attention_score, 1, 2)   # [B, 1, T]
outputs = torch.matmul(outputs, keys)              # [B, 1, E]
```

With that, the implementation of the attention network is covered, and we have a reasonably complete understanding of the DIN paper.
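As a quick sanity check, here is how the pooling layer above can be exercised with random inputs. The tensor sizes are made up for illustration, and the import path assumes the DeepCTR-Torch package (which the quoted source comes from); you can equally paste the class definitions above into your own module.

```python
import torch
from deepctr_torch.layers.sequence import AttentionSequencePoolingLayer  # or use the class quoted above

B, T, E = 2, 5, 4                               # hypothetical batch size, sequence length, embedding size
query = torch.rand(B, 1, E)                     # the candidate ad embedding
keys = torch.rand(B, T, E)                      # the user behavior sequence
keys_length = torch.tensor([[3], [5]])          # true lengths; padded steps get masked out

pooling = AttentionSequencePoolingLayer(embedding_dim=E)
pooled = pooling(query, keys, keys_length)      # [B, 1, E]
print(pooled.shape)                             # torch.Size([2, 1, 4])
```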
There is still one detail worth discussing, though: the attention weights themselves. In the DIN paper, a separate LocalActivationUnit is used to learn the weight from the concatenation of the two embeddings, which is exactly the forward method we walked through at the start of this section.

We use a separate neural network to score the pair of vectors and produce the weight, and the logic behind that weight is not necessarily a similarity between the vectors. After all, a neural network is a black box, and we cannot peek at its internal reasoning. Logically and empirically, though, we tend to expect that it behaves like a similarity measure.
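For completeness, here is a minimal, self-contained sketch of what such a scoring unit can look like, matching the forward method at the top of this section. The hidden sizes and activation are illustrative rather than the exact DeepCTR-Torch implementation (which additionally supports dropout, batch norm, and configurable activations):

```python
import torch
import torch.nn as nn


class SimpleLocalActivationUnit(nn.Module):
    """Minimal sketch of the DIN scoring unit: an MLP over the
    concatenation [query, behavior, query - behavior, query * behavior]."""

    def __init__(self, embedding_dim, hidden_units=(80, 40)):
        super().__init__()
        layers, in_dim = [], 4 * embedding_dim   # the concatenation quadruples the width
        for out_dim in hidden_units:
            layers += [nn.Linear(in_dim, out_dim), nn.Sigmoid()]
            in_dim = out_dim
        self.dnn = nn.Sequential(*layers)
        self.dense = nn.Linear(in_dim, 1)        # one score per behavior step

    def forward(self, query, user_behavior):
        # query: [B, 1, E], user_behavior: [B, T, E]
        queries = query.expand(-1, user_behavior.size(1), -1)
        attention_input = torch.cat(
            [queries, user_behavior, queries - user_behavior, queries * user_behavior],
            dim=-1)
        return self.dense(self.dnn(attention_input))  # [B, T, 1]
```

Note that nothing in this little network computes an explicit dot product or cosine similarity; whatever notion of relevance it ends up encoding is learned entirely from the click data.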