import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class TokenEmbedding(nn.Embedding):
    """Lookup table mapping token ids to ``d_model``-dimensional vectors.

    Thin wrapper around :class:`torch.nn.Embedding` that fixes the padding
    index at construction time.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Size of each embedding vector.
        padding_idx: Token id treated as padding by ``nn.Embedding``.
            Defaults to ``1``, the value that used to be hard-coded here.
    """

    def __init__(self, vocab_size, d_model, padding_idx=1):
        super().__init__(vocab_size, d_model, padding_idx=padding_idx)
class PositionalEmbedding(nn.Module):
    """Fixed (non-trainable) sinusoidal position table.

    Even feature columns hold ``sin(pos / 10000 ** (2i / d_model))`` and odd
    columns hold the matching ``cos`` term.

    The table is stored as a non-persistent buffer named ``encoding``: it
    follows the module through ``.to()`` / ``.cuda()`` but is not written to
    ``state_dict``.

    Args:
        d_model: Number of feature columns in the table. Odd values are
            supported; the final column is then a sine column.
        max_len: Number of positions (rows) to precompute.
        device: Device on which the table is initially created. Optional.
    """

    def __init__(self, d_model, max_len, device=None):
        super().__init__()
        position = torch.arange(0, max_len, device=device).float().unsqueeze(1)
        even_dims = torch.arange(0, d_model, 2, device=device)
        # Shape (max_len, ceil(d_model / 2)); shared by the sin and cos halves.
        angle = position / (10000 ** (even_dims / d_model))

        table = torch.zeros(max_len, d_model, device=device)
        table[:, 0::2] = torch.sin(angle)
        # For odd d_model there is one fewer odd column than even columns.
        table[:, 1::2] = torch.cos(angle[:, : d_model // 2])

        # A buffer (not a Parameter): moves with the module, never trained.
        self.register_buffer("encoding", table, persistent=False)

    def forward(self, x):
        """Return the first ``x.shape[1]`` rows of the position table.

        Only the size of dimension 1 of ``x`` is read; its values are ignored.
        The result has shape ``(min(x.shape[1], max_len), d_model)``.
        """
        seq_len = x.shape[1]
        return self.encoding[:seq_len, :]
class TransformerEmbedding(nn.Module):
    """Token embedding plus positional encoding, followed by dropout.

    Args:
        vocab_size: Vocabulary size passed to ``TokenEmbedding``.
        d_model: Embedding width shared by both sub-embeddings.
        max_len: Number of positions precomputed by ``PositionalEmbedding``.
        drop_prob: Dropout probability applied to the summed embedding.
        device: Device handed to ``PositionalEmbedding``.
    """

    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super().__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model)
        self.pos_emb = PositionalEmbedding(d_model, max_len, device)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        """Embed token ids ``x`` and add the position table before dropout.

        The positional term is looked up from ``x.shape[1]`` and broadcast
        over the leading dimension of the token embedding.
        """
        combined = self.tok_emb(x) + self.pos_emb(x)
        return self.dropout(combined)
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention over ``n_head`` parallel heads.

    Args:
        d_model: Width of the input and output features. Must be divisible
            by ``n_head``.
        n_head: Number of attention heads.

    Raises:
        AssertionError: If ``d_model`` is not a multiple of ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        assert d_model % n_head == 0, (
            f"d_model ({d_model}) must be divisible by n_head ({n_head})"
        )
        self.n_d = d_model // n_head  # per-head feature width
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, n_head, seq, n_d)``."""
        return x.view(batch_size, -1, self.n_head, self.n_d).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query input of shape ``(batch, q_len, d_model)``.
            k: Key input of shape ``(batch, k_len, d_model)``.
            v: Value input of shape ``(batch, k_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, n_head, q_len, k_len)``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size = q.shape[0]
        q = self._split_heads(self.w_q(q), batch_size)
        k = self._split_heads(self.w_k(k), batch_size)
        v = self._split_heads(self.w_v(v), batch_size)

        attn_score = q @ k.transpose(2, 3) / math.sqrt(self.n_d)
        if mask is not None:
            # Use the dtype's own minimum rather than a literal such as -1e9:
            # that literal is not representable in float16 and makes
            # masked_fill raise under half precision / autocast. For float32
            # the softmax result is unchanged (masked weights are exactly 0;
            # a fully masked row is still uniform).
            fill_value = torch.finfo(attn_score.dtype).min
            attn_score = attn_score.masked_fill(mask == 0, fill_value)
        attn_weight = F.softmax(attn_score, dim=-1)

        attn_output = attn_weight @ v
        # Merge heads back: (batch, n_head, q_len, n_d) -> (batch, q_len, d_model).
        attn_output = (
            attn_output.permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size, -1, self.d_model)
        )
        return self.w_o(attn_output)
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied over the last dimension.

    Computes ``fc2(dropout(relu(fc1(x))))``.

    Args:
        d_model: Input and output feature width.
        d_ff: Hidden width between the two linear layers.
        drop_prob: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, d_ff, drop_prob=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU and dropout, project back to ``d_model``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(self.dropout(hidden))
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale/shift.

    Args:
        d_model: Size of the last dimension; also the length of ``gamma``
            (initialised to ones) and ``beta`` (initialised to zeros).
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last axis, then apply ``gamma`` and ``beta``."""
        mu = x.mean(-1, keepdim=True)
        # Biased (population) variance, i.e. divide by N rather than N - 1.
        sigma_sq = x.var(-1, keepdim=True, unbiased=False)
        normalised = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        return self.gamma * normalised + self.beta
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each post-normed.

    Each sub-layer output passes through dropout, is added to the sub-layer
    input (residual connection), and is then layer-normalised.

    Args:
        d_model: Feature width of the block's input and output.
        d_ff: Hidden width of the feed-forward sub-layer.
        n_head: Number of attention heads.
        drop_prob: Dropout probability used for both residual branches and
            inside the feed-forward sub-layer.
    """

    def __init__(self, d_model, d_ff, n_head, drop_prob):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_head)
        self.drop1 = nn.Dropout(drop_prob)
        self.norm1 = LayerNorm(d_model)
        # Forward drop_prob so the FFN's internal dropout follows the
        # configured rate instead of silently staying at its 0.1 default.
        self.ffn = PositionwiseFeedForward(d_model, d_ff, drop_prob=drop_prob)
        self.drop2 = nn.Dropout(drop_prob)
        self.norm2 = LayerNorm(d_model)

    def forward(self, x, s_mask=None):
        """Run the block on ``x``.

        Args:
            x: Input activations; used as query, key and value.
            s_mask: Optional attention mask forwarded to the attention layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        residual = x
        x = self.attention(x, x, x, s_mask)
        x = self.norm1(self.drop1(x) + residual)

        residual = x
        x = self.ffn(x)
        x = self.norm2(self.drop2(x) + residual)
        return x
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, feed-forward.

    Each sub-layer output passes through dropout, is added to the sub-layer
    input (residual connection), and is then layer-normalised.

    Args:
        d_model: Feature width of the block's input and output.
        d_ff: Hidden width of the feed-forward sub-layer.
        n_head: Number of attention heads.
        drop_prob: Dropout probability used for all residual branches and
            inside the feed-forward sub-layer.
    """

    def __init__(self, d_model, d_ff, n_head, drop_prob):
        super().__init__()
        self.attention1 = MultiHeadAttention(d_model, n_head)
        self.drop1 = nn.Dropout(drop_prob)
        self.norm1 = LayerNorm(d_model)
        self.cross_attention = MultiHeadAttention(d_model, n_head)
        self.drop2 = nn.Dropout(drop_prob)
        self.norm2 = LayerNorm(d_model)
        # Forward drop_prob so the FFN's internal dropout follows the
        # configured rate instead of silently staying at its 0.1 default.
        self.ffn = PositionwiseFeedForward(d_model, d_ff, drop_prob=drop_prob)
        self.drop3 = nn.Dropout(drop_prob)
        self.norm3 = LayerNorm(d_model)

    def forward(self, dec, enc, t_mask=None, s_mask=None):
        """Run the block.

        Args:
            dec: Decoder-side activations; query/key/value of the
                self-attention and query of the cross-attention.
            enc: Encoder output; key and value of the cross-attention.
            t_mask: Optional mask for the self-attention over ``dec``.
            s_mask: Optional mask for the cross-attention over ``enc``.

        Returns:
            Tensor with the same shape as ``dec``.
        """
        residual = dec
        dec = self.attention1(dec, dec, dec, t_mask)
        dec = self.norm1(self.drop1(dec) + residual)

        residual = dec
        dec = self.cross_attention(dec, enc, enc, s_mask)
        dec = self.norm2(self.drop2(dec) + residual)

        residual = dec
        dec = self.ffn(dec)
        dec = self.norm3(self.drop3(dec) + residual)
        return dec
class Encoder(nn.Module):
    """Embedding layer followed by a stack of ``n_layer`` encoder blocks.

    Args:
        enc_voc_size: Source vocabulary size.
        max_len: Number of positions available in the positional table.
        d_model: Feature width used throughout the stack.
        d_ff: Hidden width of each block's feed-forward sub-layer.
        n_head: Number of attention heads per block.
        n_layer: Number of stacked ``EncoderLayer`` blocks.
        drop_prob: Dropout probability.
        device: Device handed to the embedding layer.
    """

    def __init__(
        self, enc_voc_size, max_len, d_model, d_ff, n_head, n_layer, drop_prob, device
    ):
        super().__init__()
        self.embedding = TransformerEmbedding(
            enc_voc_size, d_model, max_len, drop_prob, device
        )
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, d_ff, n_head, drop_prob) for _ in range(n_layer)
        )

    def forward(self, x, s_mask):
        """Embed token ids ``x`` and pass them through every block in order."""
        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden, s_mask)
        return hidden
class Decoder(nn.Module):
    """Embedding, ``n_layer`` decoder blocks, and a final vocabulary projection.

    Args:
        dec_voc_size: Target vocabulary size; also the output width of ``fc``.
        max_len: Number of positions available in the positional table.
        d_model: Feature width used throughout the stack.
        d_ff: Hidden width of each block's feed-forward sub-layer.
        n_head: Number of attention heads per block.
        n_layer: Number of stacked ``DecoderLayer`` blocks.
        drop_prob: Dropout probability.
        device: Device handed to the embedding layer.
    """

    def __init__(
        self, dec_voc_size, max_len, d_model, d_ff, n_head, n_layer, drop_prob, device
    ):
        super().__init__()
        self.embedding = TransformerEmbedding(
            dec_voc_size, d_model, max_len, drop_prob, device
        )
        self.layers = nn.ModuleList(
            DecoderLayer(d_model, d_ff, n_head, drop_prob) for _ in range(n_layer)
        )
        self.fc = nn.Linear(d_model, dec_voc_size)

    def forward(self, dec, enc, t_mask, s_mask):
        """Decode token ids ``dec`` against encoder output ``enc``.

        Returns the output of the final linear layer, whose last dimension
        has size ``dec_voc_size``. No softmax is applied.
        """
        hidden = self.embedding(dec)
        for block in self.layers:
            hidden = block(hidden, enc, t_mask, s_mask)
        return self.fc(hidden)
class Transformer(nn.Module):
    """Encoder-decoder Transformer operating on batches of token ids.

    Args:
        src_pad_idx: Padding token id in the source sequences.
        trg_pad_idx: Padding token id in the target sequences.
        enc_voc_size: Source vocabulary size.
        dec_voc_size: Target vocabulary size.
        max_len: Number of positions available in the positional tables.
        d_model: Feature width used throughout the model.
        d_ff: Hidden width of the feed-forward sub-layers.
        n_head: Number of attention heads.
        n_layer: Number of encoder blocks and of decoder blocks.
        drop_prob: Dropout probability.
        device: Device handed to the encoder and decoder and stored as
            ``self.device``.
    """

    def __init__(
        self,
        src_pad_idx,
        trg_pad_idx,
        enc_voc_size,
        dec_voc_size,
        max_len,
        d_model,
        d_ff,
        n_head,
        n_layer,
        drop_prob,
        device,
    ):
        super().__init__()
        self.encoder = Encoder(
            enc_voc_size, max_len, d_model, d_ff, n_head, n_layer, drop_prob, device
        )
        self.decoder = Decoder(
            dec_voc_size, max_len, d_model, d_ff, n_head, n_layer, drop_prob, device
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        """Build a boolean mask that hides padding on both the query and key side.

        Args:
            q: Query token ids, shape ``(batch, q_len)``.
            k: Key token ids, shape ``(batch, k_len)``.
            q_pad_idx: Padding id in ``q``.
            k_pad_idx: Padding id in ``k``.

        Returns:
            Bool tensor of shape ``(batch, 1, q_len, k_len)`` that is ``True``
            where neither the query token nor the key token is padding.
        """
        # Broadcasting yields the full (batch, 1, q_len, k_len) grid without
        # materialising two repeated copies first.
        q_keep = q.ne(q_pad_idx)[:, None, :, None]
        k_keep = k.ne(k_pad_idx)[:, None, None, :]
        return q_keep & k_keep

    def make_causal_mask(self, q, k):
        """Return a ``(q_len, k_len)`` lower-triangular bool mask.

        The mask is created on the device of ``q`` so it always matches the
        inputs, even if the model was moved after construction.
        """
        q_len, k_len = q.shape[1], k.shape[1]
        return torch.tril(torch.ones(q_len, k_len, device=q.device)).bool()

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg`` against it.

        Args:
            src: Source token ids, shape ``(batch, src_len)``.
            trg: Target token ids, shape ``(batch, trg_len)``.

        Returns:
            The decoder output for every target position.
        """
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        # Target self-attention: hide padding and future positions.
        trg_mask = self.make_pad_mask(
            trg, trg, self.trg_pad_idx, self.trg_pad_idx
        ) & self.make_causal_mask(trg, trg)
        # Cross-attention: target queries over source keys.
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)

        enc = self.encoder(src, src_mask)
        return self.decoder(trg, enc, trg_mask, src_trg_mask)
if __name__ == "__main__":
    # Smoke test: build a standard-size model and run one forward pass.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"运行设备: {device}")

    # Hyperparameters
    hparams = dict(
        src_pad_idx=1,
        trg_pad_idx=1,
        enc_voc_size=37000,
        dec_voc_size=37000,
        max_len=512,
        d_model=512,
        d_ff=2048,
        n_head=8,
        n_layer=6,
        drop_prob=0.1,
        device=device,
    )

    # Build the model
    model = Transformer(**hparams).to(device)

    # Parameter count
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    print(f"标准Transformer总参数量: {total_params:,}")

    # Random test inputs
    batch_size, src_len, trg_len = 4, 64, 64
    src = torch.randint(0, hparams["enc_voc_size"], (batch_size, src_len)).to(device)
    trg = torch.randint(0, hparams["dec_voc_size"], (batch_size, trg_len)).to(device)

    # Forward pass in eval mode, without gradient tracking
    model.eval()
    with torch.no_grad():
        out = model(src, trg)

    # Shape check
    print(f"src shape: {src.shape}")
    print(f"trg shape: {trg.shape}")
    print(f"output shape: {out.shape}")
    assert out.shape == (batch_size, trg_len, hparams["dec_voc_size"])
    print("✅ 标准参数 Transformer 前向传播测试通过")