将26个token压缩成1个,新方法极致节省ChatGPT输入框空间( 二 )


通过掩膜学习 Gisting上文描述了 Gisting 的一般框架,接下来将探讨一种学习此类模型的极简单方法:使用 LM 本身用作 Gist 预测器 G 。这不仅利用了 LM 中的预存在知识,而且允许通过简单地执行标准指令微调来学习 gisting 并修改 Transformer 注意力掩膜来增强 prompt 压缩 。这意味着 Gisting 不会产生额外训练成本,只需要基于标准指令微调即可!
具体来说,向模型词汇表和嵌入矩阵中添加一个特殊的 gist token,类似于此类模型中常见的句子开头 / 结尾 token 。然后对于给定的(任务,输入)元组(t,x),使用 (t, g_1, . . . , g_k, x) 中一组 k 个连续的 gist token 将 t 和 x 连接在一起,例如

将26个token压缩成1个,新方法极致节省ChatGPT输入框空间

文章插图
。这个序列被输入到模型中,有一个限制,即在 gist token 之后的输入 token 不能参考之前的 prompt token(但它们可以参考 gist token) 。这会强制模型将 prompt 中的信息压缩成 gist token,因为输入 x (输出 y) 无法处理 prompt t 。 
下图 2 展示了所需要的更改 。对于 GPT-3 或 LLaMA 等通常采用自回归因果注意力掩膜的 decoder-only LM,只需 mask out 图 2a 所示的三角形左下角 。对于具有双向编码器和自回归解码器的 encoder-decoder LM,则需要进行两项修改(图 2b 所示) 。
首先,在通常没有掩膜的编码器中,阻止输入 token x 参考 prompt token t 。但还必须防止 prompt t 和 gist token g_i 参考输入 token x,否则编码器将根据输入学习不同的 gist 表示 。最后解码器正常运行,除了在交叉注意力期间,这时需要阻止解码器参考 prompt token t 。
 
将26个token压缩成1个,新方法极致节省ChatGPT输入框空间

文章插图
 
实验结果对于不同数量的 gist token,LLaMA-7B 和 FLAN-T5-XXL 的 ROUGE-L 和 ChatGPT 评估结果如下图 3 所示 。
 
将26个token压缩成1个,新方法极致节省ChatGPT输入框空间

文章插图
 
模型通常对 gist token 的数量 k 不敏感:将 prompt 压缩到单个 token 并不会导致显著性能下降 。事实上,在某些情况下,过多的 gist token 会损害性能 (例如 LLaMA-7B, 10 gist tokens),这可能是因为增加的容量使训练分布过拟合 。因此,研究者在下表 1 中给出了单 token 模型的具体数值,并在剩余实验中使用单个 gist 模型 。
 
将26个token压缩成1个,新方法极致节省ChatGPT输入框空间

文章插图
 
在见过的指令上,gist 模型获得了与其对应阳性对照模型几乎相同的 ROUGE 和 ChatGPT 性能,在 LLaMA-7B FLANT5-XXL 上的胜率分别为 48.6% 和 50.8% 。这里研究者最感兴趣的是它们在未见过任务上的泛化能力,这需要通过另外两个数据集来衡量的 。
在 Alpaca+ 训练数据集中未见过的 prompt 中,可以看到 gist 模型在未见过 prompt 上有着强大的泛化能力:与对照组相比,分别有 49.7%(LLaMA)和 46.2%(FLAN-T5)的胜率 。在最具挑战性的 OOD Human split 上,gist 模型的胜率略微下降,分别为 45.8%(LLaMA)和 42.5%(FLANT5) 。
本文的目的是让 gist 模型紧密地模仿原始模型的功能,因此有人可能会问究竟什么时候 gist 模型与对照组无差别 。下图 4 说明了这种情况发生的频率:对于已见过任务(但是未见过的输入),gist 模型几乎有一半的时间与对照组不相上下 。对于未见过的任务,这一数字下降到了 20-25% 。对于 OOD Human 任务,这一数字又下降到 10% 。无论如何,gist 模型输出的质量是很高的 。
 
将26个token压缩成1个,新方法极致节省ChatGPT输入框空间

文章插图
 
总的来说,这些结果表明,gist 模型可以可靠地压缩 prompt,甚至在训练分布之外的某些 prompt 上也可以做到这一点,特别是像 LLaMA 这样的 decoder-only 因果 LM 。FLAN-T5 等 encoder-decoder 模型表现略差,一个可能的原因是 gist 掩膜抑制了编码器中的双向注意力流,这比仅 mask 自回归解码器的一部分 history 更具挑战性 。未来需要进一步的工作来研究这个假设 。
计算、内存和存储效率最后,回到这项工作的核心动机之一:gisting 可以带来什么样的效率提升?
下表 2 展示了使用 PyTorch 2.0 分析器对模型进行单次前向传递的结果(即使用单个输入 token 的自回归解码的一步),并对 Human eval split 中的 252 个指令取平均值 。与未经优化的模型相比,gist 缓存显著提高了效率 。两种模型的 FLOPs 节约率达到了 40%,时钟时间降低了 4-7% 。
 
将26个token压缩成1个,新方法极致节省ChatGPT输入框空间


推荐阅读