def forward(self, query, key, value, mask=None):
batch_size = query.size(0) #获取 batch 的大小
# Linear projections
q = self.w_q(query) # [batch_size, seq_len, d_model]
k = self.w_k(key) # [batch_size, seq_len, d_model]
v = self.w_v(value) # [batch_size, seq_len, d_model]
# Split the embeddings into multiple heads
q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # [batch_size, n_heads, seq_len, d_k]
k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # [batch_size, n_heads, seq_len, d_k]
v = v.view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2) # [batch_size, n_heads, seq_len, d_v]
python
运行
def forward(self, query, key, value, mask=None):
batch_size = query.size(0) #获取 batch 的大小
解析:获取批次大小(batch_size)
query.size(0):获取query张量的第 0 个维度,也就是「批次大小」(比如你之前设置的batch_size=8)。- 作用:后续所有操作(拆分多头、维度变换)都需要用到
batch_size,确保每一步的张量形状匹配,避免维度报错。 - 通俗理解:
batch_size就是 “一次训练的样本数量”,比如一次喂给模型 8 个样本,batch_size就等于 8。
python
运行
# Linear projections(线性投影)
q = self.w_q(query) # [batch_size, seq_len, d_model]
k = self.w_k(key) # [batch_size, seq_len, d_model]
k = self.w_k(key) # [batch_size, seq_len, d_model]
v = self.w_v(value) # [batch_size, seq_len, d_model]
解析:线性投影(核心步骤)
前提:
self.w_q、self.w_k、self.w_v就是你之前定义的 3 个独立nn.Linear(d_model, d_model)层(各自有独立权重)。作用:将输入的
query、key、value(原始特征向量),通过线性变换(矩阵乘法 + 偏置),映射到不同的子空间,生成适配多头注意力计算的特征向量。形状说明(注释已标注,补充理解):
- 输入
query/key/value:形状默认是[batch_size, seq_len, d_model](批次大小、序列长度、模型维度); - 输出
q/k/v:形状不变,还是[batch_size, seq_len, d_model]—— 线性投影只改变特征向量的 “内容”(映射到新子空间),不改变张量的维度。
- 输入
关键:这一步是为后续 “拆分多头” 做准备,将原始特征转换成适合注意力计算的格式。
python
运行
# Split the embeddings into multiple heads(将嵌入向量拆分为多个头)
q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # [batch_size, n_heads, seq_len, d_k]
k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # [batch_size, n_heads, seq_len, d_k]
v = v.view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2) # [batch_size, n_heads, seq_len, d_v]
解析:拆分多头(多头注意力的核心操作)
这一步是将线性投影后的q、k、v,拆分成多个 “注意力头”,让模型能同时关注不同的特征,步骤拆解 + 通俗理解:
1. 先看核心函数:view 和 transpose
view(shape):重塑张量形状,不改变张量内的数据,只改变 “查看方式”(比如把一个 [8,20,64] 的张量,改成 [8,20,2,32]);- 这里的
-1:表示 “自动计算该维度的大小”,避免手动计算出错(比如q的形状是 [8,20,64],view(8, -1, 2, 32),-1会自动计算为 20)。
- 这里的
transpose(1, 2):交换张量的第 1 个和第 2 个维度(对应代码中的seq_len和n_heads维度),调整维度顺序适配注意力计算。
2. 逐行拆解(结合你定义的参数)
假设你设置:n_heads=2(注意力头数)、d_model=64(模型维度)、d_k=d_v=32(每个头的维度,d_model ÷ n_heads),seq_len=20、batch_size=8:
第一步:
q.view(batch_size, -1, self.n_heads, self.d_k)把
q(形状 [8,20,64])重塑为[8, 20, 2, 32]:- 8:batch_size(批次大小);
- 20:seq_len(序列长度);
- 2:n_heads(注意力头数);
- 32:d_k(每个注意力头的维度)。
第二步:
.transpose(1, 2)交换第 1 个维度(20,seq_len)和第 2 个维度(2,n_heads),最终形状变为
[8, 2, 20, 32],也就是注释中的[batch_size, n_heads, seq_len, d_k]。
3. 为什么要拆分多头?
通俗理解:一个注意力头只能关注序列的某一种特征(比如一个头关注语法,一个头关注语义),拆分多个头,能让模型同时关注序列的多种特征,提升注意力机制的效果。
4. 关键注意事项
d_model必须能被n_heads整除(比如 64÷2=32),否则view重塑时会报错(无法均匀拆分每个头的维度);q和k的每个头维度都是d_k(确保后续注意力分数计算时,矩阵乘法维度匹配);v的每个头维度是d_v(可与d_k相同,也可不同,最终会通过线性层统一维度);- 拆分后,
q、k、v的形状完全匹配,可直接进入下一步 “注意力分数计算”。