profileName: youpingfang postId: 293 postType: post categories:
- 6
简介:受@3Blue1Brown 的可视化数学系列启发的第三期,这期Attention机制的视频用来作为之后CLIP原理可视化视频的必要铺垫。本期将以可视化数学演示来带大家无门槛直观的理解transformer模型和它背后的attention机制的核心原理。
一、为什么需要Attention机制
1.1 神经网络的基本工作原理
神经网络的核心运作模式是将数字信息流经网络各节点,最终得到输出结果。对于非数字类型的输入输出任务,需要先将输入转化为数字,再输入神经网络,最后将输出的数字转化为目标数据类型。
1.2 信息规模与计算资源的矛盾
当我们增大输入信息规模时,神经网络也必须增大规模以适配更大的输入。然而,神经网络规模的增大速度远快于输入信息规模的增长速度。如果单纯增大输入信息规模,神经网络的规模很快就会超过硬件的承受能力。
1.3 传统方案的困境:信息传递中的稀释问题
在Transformer出现前,主流解决方案是使用循环神经网络(RNN)风格的序列处理:
工作方式: 1. 只输入当前词汇,让其流经神经网络,输出预测的下一词汇 2. 将神经网络输出层的前一层的数值提取出来(这些中间层数值包含输入数据的信息) 3. 在预测之后的词汇时,把下一个词和之前提取的中间层数据都输入神经网络 4. 持续这种循环,中间的中间层数值作为某种压缩,用来包含之前全部输入词汇的信息
核心问题:早期输入神经网络的信息在中间层的传递中会被严重稀释。
由于每一次中间层数值都会伴随当前输入词汇一起输入,每一步输入中间层数值里包含的上文信息都会被当前输入词汇的信息在神经网络的传递中挤压。中间层节点能储存的信息总量是固定的,之前的上文信息就会逐渐被压缩,而且越早输入的词汇被压缩得越严重。当我们非常依赖早期的前文信息进行预测时,却发现早期的前文信息已经快被压缩没了。
二、Attention机制的核心原理
2.1 结构创新:三组节点设计
Attention机制对神经网络进行了关键修改,在输入层节点后插入一种新的上下文压缩机制,包含三组分别与输入层节点连接的节点:
- Query(查询)节点组:用于查询的节点组
- Key(键)节点组:用于被查询的节点组
- 普通中间层节点组:与普通神经网络中间层没什么两样
2.2 Attention计算流程
输入阶段:将输入层的数值分别经过对应的层间计算,传递到这三组节点上。
计算阶段: 1. 取出Query节点组的数值和Key节点组的数值 2. 这些数值不会直接用于后续传递,而是直接对应相乘,然后加起来得到一个权重数值 3. 用这个权重数值去乘以普通中间层的数值
输出阶段: 1. Query和Key节点组的使命结束 2. 将输入层、被计算过权重的普通中间层与后续的层连上 3. 让后续层同时接受输入信息和被计算过权重的普通中间层的信息 4. 流经网络输出预测结果
2.3 多序列协同推理
在实际推理过程中,使用多个完全一样的神经网络副本协同工作:
- 将所有上文词汇分别输入各自网络的Attention层
- 用当前词汇的Query节点组数值分别去与各网络中Key节点组的数值相乘并相加
- 分别拿权重乘以对应的普通中间层的数值
- 把两份被乘以权重的普通中间层数值叠加起来,作为最终的中间层数值
- 传入后续神经网络完成推理
2.4 关联度计算与信息加权
计算结果代表当前输入词汇的推理任务与对应词汇信息的相关程度: - 数值越大,关联程度越高 - 用相关程度乘以它们的普通中间层的数值,然后叠加起来传入神经网络完成推理
关键优势:通过让每个上文词汇的Key节点组数值都有均等的机会与当前词汇的Query节点组数值进行计算,有效避免了前文的信息被不断压缩稀释的问题。
例如,当我们需要使用第一个输入词汇"猫"的信息进行下文推理时,由于"猫"的信息与当前输入词汇的推理任务的关联度较高,当前词汇的Query节点组数值和"猫"的Key节点组数值的计算结果高达0.9,远大于其他词汇的计算结果。来自"猫"的信息在新的对上文信息的压缩总结中被赋予了更高的比重。虽然"猫"这一词是最早输入模型的,但仍然能够被有效利用起来。
三、Attention机制的自我训练演化
3.1 训练过程
神经网络训练的基本流程: 1. 同时把要输入的内容和要得到的内容放在神经网络两端 2. 让输入的内容流经网络得到预测结果 3. 对比神经网络预测结果和目标结果 4. 把差别回传到网络中,让网络各个节点根据差异修正自己的参数 5. 使得在下一次预测的结果更贴近目标结果
3.2 为什么Query和Key会"老老实实"工作
Attention训练的核心驱动逻辑:
由于预测的最后数值是从包含当前输入词汇信息的输入层节点和包含了对所有上文词汇信息的压缩的中间层向后面的神经网络传递的,所以这部分网络的连线所代表的权重会很自然地被更新,使得下次预测结果差异更小。
而中间层的数值本身是由各个词汇网络的普通中间层乘以对应的权重数值得到的,所以要使得下次预测的差异更小,就要让与当前输入词汇的推理任务有关联的词汇的中间层权重更高,这样有关联的词汇的信息才能更多地被包含在压缩的上文信息中。
以"猫喜欢抓老鼠"为例: - 对于当前输入"抓"的推理任务来说,要成功预测出下一个词"老鼠",明显需要获得更多来自上文"猫"的信息 - 要增加"猫"的信息,必须增加"猫"的Attention权重 - "猫"的Attention权重是由"抓"的Query节点组和"猫"的Key节点组的数值计算得到的 - 因此,输入节点和Query节点组与输入节点和Key节点组间的连线权重必须按照能够让Query节点组和Key节点组的数值在计算后能得到更大Attention权重数值的方向进行改变
3.3 自发演化方向
在这种驱动下,Query节点组和Key节点组会自然而然地朝着专门用来查询和用来被查询的方向演化,而不是朝着变成一个普通中间层的方向演化。这样它们才能够在词汇有关联时给出能够让Attention权重计算结果更大的数值,在词汇没有关联时给出能够让Attention权重计算结果更小的数值。
按照这种逻辑,可以对输入层和Attention层之间的连线权重进行更新。由于三个神经网络都是对同一个含Attention机制的神经网络的复制体,需要对每个网络的Attention层都进行更新,然后把更新一起平均起来,作为对神经网络Attention输入层和Attention层之间的连线权重的最终更新。
四、KV Cache优化
4.1 计算冗余问题
在动画演示中,每次输入一个词汇来推理下一个词汇时,都需要把前文所有词汇都输入神经网络,然后流经Attention层来得到Key节点组和普通中间层节点。很明显这种计算是重复浪费的。
4.2 缓存机制
在实际的包含Attention机制的神经网络中,这些前文词汇的Key节点组和普通中间层节点数值都可以被存起来,当前输入词汇的Query节点组直接与这些存起来的Key节点组和普通中间层节点的数值进行计算和加权就行了。
这样每次推理时,数据都只需要流经当前的神经网络,能大幅节约计算资源。这种优化手段一般被叫做KV Cache(Key Value Cache)。
五、Transformer的核心贡献
通过上面的机制,我们成功实现了在不扩充神经网络规模的前提下,有效地利用起所有上文信息来让模型做出正确预测。这也是Transformer模型的核心思想。
Transformer的核心创新: 1. 并行处理:突破了RNN序列处理的限制,可以并行处理整个序列 2. 直接关联:通过Attention机制,让任意两个位置之间可以直接建立关联,无需通过中间层传递 3. 信息保留:有效解决了早期信息在传递过程中被稀释的问题
六、面试题与答案
题目1:为什么我们需要Attention机制?
答案:
主要有两个原因:
信息规模与计算资源的矛盾:增大输入信息规模时,神经网络规模需要更大,而神经网络规模的增长速度远快于输入信息的增长速度,很快就会超过硬件承受能力。
传统RNN的信息稀释问题:在传统循环结构中,早期输入的信息在中间层传递时会被后续输入的信息不断挤压和稀释,越早输入的信息被压缩得越严重,导致模型无法有效利用早期上下文信息。
Attention机制通过Query-Key-Value的设计,让模型能够直接建立任意位置之间的关联,有效解决了信息稀释问题。
题目2:描述Attention机制的核心计算流程。
答案:
Attention机制的核心计算流程如下:
- 三组节点:输入层后插入三组节点——Query(查询)、Key(键)、普通中间层
- 信息传递:输入层数值经过计算分别传递到三组节点
- 权重计算:Query节点的数值与对应Key节点的数值相乘并相加,得到权重
- 信息加权:用权重乘以普通中间层的数值
- 输出融合:结合原始输入和加权后的中间层信息,传递给后续网络
核心公式为:$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
题目3:Query、Key、Value分别代表什么含义?
答案:
- Query(查询):代表当前token想要查询什么信息,可以理解为"我需要什么信息"
- Key(键):代表每个token的索引标识,可以理解为"我有什么信息"
- Value(值):代表每个token的实际内容信息
Attention机制的本质是:用Query去和所有Key做匹配(计算相似度),然后根据匹配程度(权重)对相应的Value进行加权求和。匹配度越高的Value,在最终结果中占比越大。
题目4:为什么Query和Key节点组会自然地朝着"查询"和"被查询"的方向演化,而不是变成普通中间层?
答案:
这是由训练目标的驱动造成的:
- 最终目标:使预测结果与真实结果的差异最小化
- 中间层的作用:中间层数值由普通中间层乘以Attention权重得到,要减少预测差异,就需要让与当前任务关联的词汇获得更高的Attention权重
- 权重的来源:Attention权重由Query和Key的数值计算得到
- 训练压力:因此连接输入层与Query/Key的权重必须朝着"在词汇有关联时产生更大权重,无关联时产生更小权重"的方向调整
- 自然分化:这种训练压力使得Query专门负责"查询",Key专门负责"被查询",而非演化成普通的中间层
题目5:Attention机制如何解决RNN的信息稀释问题?
答案:
RNN的信息稀释问题根源在于:中间层容量固定,新增信息会挤压旧信息。
Attention的解决思路:
- 直接关联:Query可以直接与任意位置的Key计算关联度,无需通过层层传递
- 均等机会:每个上文词汇的Key都有均等机会与当前Query计算
- 按需加权:关联度高的词汇自动获得更高权重,无论它在序列中的位置多早
- 并行处理:可以同时计算所有位置对的关联度,避免顺序传递中的信息损失
举例:即使"猫"是序列中第一个词,当"抓"需要它的信息时,可以通过Query-Key计算直接获取高达0.9的关联权重,而不会因为经历了多次信息传递被稀释掉。
题目6:什么是KV Cache?它的作用是什么?
答案:
KV Cache(Key Value Cache)是一种推理优化技术:
作用:将之前计算过的Key和Value缓存起来,避免重复计算。
工作流程: - 第一次输入"你是谁?"时,计算4个token的K、V并缓存 - 生成下一个token时,只需计算新token的K、V,与缓存的K、V拼接 - 无需重新计算"你是谁?"的K、V
价值:大幅减少重复计算,提升推理速度,本质是用显存空间换计算时间。
题目7:Transformer相比传统RNN有哪些核心优势?
答案:
| 方面 | RNN | Transformer |
|---|---|---|
| 计算方式 | 顺序串行 | 可并行处理 |
| 信息传递 | 通过隐藏层逐层传递 | 直接通过Attention |
| 长期依赖 | 距离越远衰减越严重 | 直接建立任意位置关联 |
| 训练效率 | 受序列长度限制大 | 可高效并行训练 |
| 信息保留 | 早期信息易被稀释 | 各位置信息等权参与 |
Transformer的核心优势在于:突破序列长度的限制,直接建模任意距离的依赖关系,同时支持高效并行计算。
题目8:Attention计算中为什么要除以$\sqrt{d_k}$?
答案:
除以$\sqrt{d_k}$(d_k是Key向量的维度)主要有以下原因:
控制方差:Q和K的点积结果方差会随维度增大而增大,除以$\sqrt{d_k}$可以使得点积结果的方差稳定在1左右
避免softmax饱和:如果点积结果数值过大,经过softmax后会产生接近one-hot的分布(梯度趋近于0),导致训练困难
数值稳定:保持注意力权重的分布合理,梯度的尺度恰当,有利于模型训练
这是一个经验性的调优技巧,帮助Transformer在各种维度下都能稳定训练。
题目9:多头注意力(MHA)的作用是什么?为什么需要多头?
答案:
单头注意力的局限:只能捕获一种类型的关联关系
多头注意力的优势: 1. 多子空间学习:不同的头可以学习关注不同的语义关联 2. 多语义捕获:有的头可能关注语法关系,有的关注语义相似性,有的关注位置关系等 3. 增强表达能力:多个头的输出拼接后,通过线性变换整合,提供更丰富的表示
直觉理解:就像人类理解一段话时可以从多个角度分析,多头注意力让模型也能并行地从多个角度捕获信息。
题目10:描述Attention在Transformer中的位置和作用。
答案:
在Transformer的Encoder和Decoder中,Attention处于核心位置:
Encoder结构:N个相同的层堆叠,每层包含: - Multi-Head Self-Attention(自注意力) - Feed-Forward Network(前馈网络)
Decoder结构:每层包含: - Masked Multi-Head Self-Attention(带掩码的自注意力) - Encoder-Decoder Attention(编码器-解码器注意力) - Feed-Forward Network
作用: 1. Self-Attention:让序列内任意位置直接建立关联,捕获上下文依赖 2. Cross-Attention:Decoder通过Query与Encoder的Key-Value交互,获取输入序列的信息 3. 信息整合:将分散在序列各处的相关信息聚合到当前表示中