def forward(self, query, key, value, mask=None):

	batch_size = query.size(0) #获取 batch 的大小
	
	# Linear projections
	
	q = self.w_q(query) # [batch_size, seq_len, d_model]
	
	k = self.w_k(key) # [batch_size, seq_len, d_model]
	
	v = self.w_v(value) # [batch_size, seq_len, d_model]
	
	# Split the embeddings into multiple heads
	
	q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # [batch_size, n_heads, seq_len, d_k]
	
	k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # [batch_size, n_heads, seq_len, d_k]
	
	v = v.view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2) # [batch_size, n_heads, seq_len, d_v]

python

运行

def forward(self, query, key, value, mask=None):
    batch_size = query.size(0) #获取 batch 的大小

解析:获取批次大小(batch_size)

  • query.size(0):获取query张量的第 0 个维度,也就是「批次大小」(比如你之前设置的batch_size=8)。
  • 作用:后续所有操作(拆分多头、维度变换)都需要用到batch_size,确保每一步的张量形状匹配,避免维度报错。
  • 通俗理解:batch_size就是 “一次训练的样本数量”,比如一次喂给模型 8 个样本,batch_size就等于 8。

python

运行

# Linear projections(线性投影)
q = self.w_q(query)  # [batch_size, seq_len, d_model]
k = self.w_k(key)    # [batch_size, seq_len, d_model]
k = self.w_k(key)    # [batch_size, seq_len, d_model]
v = self.w_v(value)  # [batch_size, seq_len, d_model]

解析:线性投影(核心步骤)

  1. 前提:self.w_qself.w_kself.w_v 就是你之前定义的 3 个独立nn.Linear(d_model, d_model)层(各自有独立权重)。

  2. 作用:将输入的querykeyvalue(原始特征向量),通过线性变换(矩阵乘法 + 偏置),映射到不同的子空间,生成适配多头注意力计算的特征向量。

  3. 形状说明(注释已标注,补充理解):

    • 输入query/key/value:形状默认是 [batch_size, seq_len, d_model](批次大小、序列长度、模型维度);
    • 输出q/k/v:形状不变,还是 [batch_size, seq_len, d_model]—— 线性投影只改变特征向量的 “内容”(映射到新子空间),不改变张量的维度。
  4. 关键:这一步是为后续 “拆分多头” 做准备,将原始特征转换成适合注意力计算的格式。

python

运行

# Split the embeddings into multiple heads(将嵌入向量拆分为多个头)
q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # [batch_size, n_heads, seq_len, d_k]
k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # [batch_size, n_heads, seq_len, d_k]
v = v.view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)  # [batch_size, n_heads, seq_len, d_v]

解析:拆分多头(多头注意力的核心操作)

这一步是将线性投影后的qkv,拆分成多个 “注意力头”,让模型能同时关注不同的特征,步骤拆解 + 通俗理解:

1. 先看核心函数:view 和 transpose

  • view(shape):重塑张量形状,不改变张量内的数据,只改变 “查看方式”(比如把一个 [8,20,64] 的张量,改成 [8,20,2,32]);

    • 这里的 -1:表示 “自动计算该维度的大小”,避免手动计算出错(比如q的形状是 [8,20,64],view(8, -1, 2, 32)-1会自动计算为 20)。
  • transpose(1, 2):交换张量的第 1 个和第 2 个维度(对应代码中的seq_lenn_heads维度),调整维度顺序适配注意力计算。

2. 逐行拆解(结合你定义的参数)

假设你设置:n_heads=2(注意力头数)、d_model=64(模型维度)、d_k=d_v=32(每个头的维度,d_model ÷ n_heads),seq_len=20batch_size=8

  • 第一步:q.view(batch_size, -1, self.n_heads, self.d_k)

    q(形状 [8,20,64])重塑为 [8, 20, 2, 32]

    • 8:batch_size(批次大小);
    • 20:seq_len(序列长度);
    • 2:n_heads(注意力头数);
    • 32:d_k(每个注意力头的维度)。
  • 第二步:.transpose(1, 2)

    交换第 1 个维度(20,seq_len)和第 2 个维度(2,n_heads),最终形状变为 [8, 2, 20, 32],也就是注释中的 [batch_size, n_heads, seq_len, d_k]

3. 为什么要拆分多头?

通俗理解:一个注意力头只能关注序列的某一种特征(比如一个头关注语法,一个头关注语义),拆分多个头,能让模型同时关注序列的多种特征,提升注意力机制的效果。

4. 关键注意事项

  • d_model 必须能被 n_heads 整除(比如 64÷2=32),否则view重塑时会报错(无法均匀拆分每个头的维度);
  • qk的每个头维度都是d_k(确保后续注意力分数计算时,矩阵乘法维度匹配);
  • v的每个头维度是d_v(可与d_k相同,也可不同,最终会通过线性层统一维度);
  • 拆分后,qkv的形状完全匹配,可直接进入下一步 “注意力分数计算”。