import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class TokenEmbedding(nn.Embedding):
    """Learned lookup table mapping token ids to ``d_model``-wide vectors.

    Args:
        vocab_size: number of rows (distinct token ids) in the table.
        d_model: width of each embedding vector.
        padding_idx: row that ``nn.Embedding`` initializes to zeros and excludes
            from gradient updates. Defaults to 1, the value previously
            hard-coded here, so existing callers are unaffected.
    """

    def __init__(self, vocab_size, d_model, padding_idx=1):
        super().__init__(vocab_size, d_model, padding_idx=padding_idx)


class PositionalEmbedding(nn.Module):
    """Fixed (non-trainable) sinusoidal positional encodings.

    Row ``pos`` of the table holds ``sin(pos / 10000**(2i / d_model))`` in even
    column ``2i`` and ``cos(pos / 10000**(2i / d_model))`` in odd column ``2i + 1``.

    Args:
        d_model: encoding width. Odd widths are supported; the last column is
            then a sine column with no cosine partner.
        max_len: number of positions precomputed.
        device: device the table is first created on. The table is registered
            as a buffer, so ``module.to(...)`` / ``.half()`` also move/cast it.
    """

    def __init__(self, d_model, max_len, device):
        super().__init__()
        pos = torch.arange(0, max_len, device=device, dtype=torch.float).unsqueeze(1)
        _2i = torch.arange(0, d_model, 2, device=device, dtype=torch.float)
        # (max_len, ceil(d_model / 2)) angle table shared by the sin and cos columns.
        angles = pos / (10000 ** (_2i / d_model))

        encoding = torch.zeros(max_len, d_model, device=device)
        encoding[:, 0::2] = torch.sin(angles)
        # With odd d_model there is one fewer odd column than even columns.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer (not a Parameter) is never trained but follows .to()/.half().
        # persistent=False keeps it out of state_dict, so checkpoints saved
        # before this change still load with strict=True.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the first ``x.shape[1]`` rows of the table, shape (seq_len, d_model).

        Only ``x.shape[1]`` is read; the values in ``x`` are ignored.

        Raises:
            ValueError: if ``x.shape[1]`` exceeds ``max_len``.
        """
        seq_len = x.shape[1]
        max_len = self.encoding.shape[0]
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        return self.encoding[:seq_len, :]


class TransformerEmbedding(nn.Module):
    """Token embedding plus positional encoding, followed by dropout.

    Args:
        vocab_size: number of token ids in the vocabulary.
        d_model: embedding width.
        max_len: number of positions the positional table covers.
        drop_prob: dropout probability applied to the summed embeddings.
        device: device for the positional table.
    """

    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super().__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model)
        self.pos_emb = PositionalEmbedding(d_model, max_len, device)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        # pos_emb yields a (seq_len, d_model) slice that broadcasts over the batch.
        return self.dropout(self.tok_emb(x) + self.pos_emb(x))


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with input/output projections.

    Args:
        d_model: model width; must be divisible by ``n_head``.
        n_head: number of attention heads, each of width ``d_model // n_head``.
    """

    def __init__(self, d_model, n_head):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head

        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        self.n_d = d_model // n_head  # per-head width

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: queries, expected (batch, len_q, d_model).
            k: keys, expected (batch, len_k, d_model).
            v: values, expected (batch, len_k, d_model).
            mask: optional tensor broadcastable to (batch, n_head, len_q, len_k);
                score positions where ``mask == 0`` are masked out before softmax.

        Returns:
            Tensor of shape (batch, len_q, d_model).
        """
        batch_size = q.shape[0]

        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        # Split d_model into n_head heads: (batch, n_head, len, n_d).
        q = q.view(batch_size, -1, self.n_head, self.n_d).permute(0, 2, 1, 3)
        k = k.view(batch_size, -1, self.n_head, self.n_d).permute(0, 2, 1, 3)
        v = v.view(batch_size, -1, self.n_head, self.n_d).permute(0, 2, 1, 3)

        # Scaled scores: (batch, n_head, len_q, len_k).
        attn_score = q @ k.transpose(2, 3) / math.sqrt(self.n_d)

        if mask is not None:
            # Fill with the most negative finite value of the score dtype. A
            # fixed -1e9 is not representable in float16, so masked_fill raises
            # under half precision / autocast. In float32 both fill values give
            # the masked positions zero weight after softmax.
            attn_score = attn_score.masked_fill(
                mask == 0, torch.finfo(attn_score.dtype).min
            )

        attn_weight = F.softmax(attn_score, dim=-1)

        attn_output = attn_weight @ v

        # Merge heads back: (batch, len_q, d_model).
        attn_output = (
            attn_output.permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size, -1, self.d_model)
        )

        return self.w_o(attn_output)


class PositionwiseFeedForward(nn.Module):
    """Per-position MLP: Linear(d_model -> d_ff) -> ReLU -> Dropout -> Linear(d_ff -> d_model).

    Args:
        d_model: input/output width.
        d_ff: hidden width.
        drop_prob: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, drop_prob=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand
        self.fc2 = nn.Linear(d_ff, d_model)  # project back
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(self.dropout(hidden))


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale/shift.

    Uses the biased (population) variance and computes
    ``gamma * (x - mean) / sqrt(var + eps) + beta``.

    Args:
        d_model: size of the normalized (last) dimension.
        eps: added to the variance for numerical stability.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))  # per-feature scale
        self.beta = nn.Parameter(torch.zeros(d_model))  # per-feature shift
        self.eps = eps

    def forward(self, x):
        mu = x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, unbiased=False, keepdim=True)
        x_hat = (x - mu) / torch.sqrt(variance + self.eps)
        return self.gamma * x_hat + self.beta


class EncoderLayer(nn.Module):
    """One post-norm encoder layer: self-attention, then a feed-forward network.

    Each sub-layer is followed by dropout, a residual add and layer norm:
    ``x = norm(x + dropout(sublayer(x)))``.

    Args:
        d_model: model width.
        d_ff: hidden width of the feed-forward network.
        n_head: number of attention heads.
        drop_prob: dropout probability used for every dropout in the layer.
    """

    def __init__(self, d_model, d_ff, n_head, drop_prob):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_head)
        self.drop1 = nn.Dropout(drop_prob)
        self.norm1 = LayerNorm(d_model)

        # Forward drop_prob so the FFN's internal dropout honors the configured
        # rate instead of silently using PositionwiseFeedForward's 0.1 default.
        self.ffn = PositionwiseFeedForward(d_model, d_ff, drop_prob)
        self.drop2 = nn.Dropout(drop_prob)
        self.norm2 = LayerNorm(d_model)

    def forward(self, x, s_mask=None):
        """Apply the layer; ``s_mask`` is passed unchanged to self-attention."""
        _x = x
        x = self.attention(x, x, x, s_mask)
        x = self.drop1(x)
        x = self.norm1(x + _x)

        _x = x
        x = self.ffn(x)
        x = self.drop2(x)
        x = self.norm2(x + _x)
        return x


class DecoderLayer(nn.Module):
    """One post-norm decoder layer: masked self-attention, cross-attention, FFN.

    Each sub-layer is followed by dropout, a residual add and layer norm:
    ``dec = norm(dec + dropout(sublayer(dec)))``.

    Args:
        d_model: model width.
        d_ff: hidden width of the feed-forward network.
        n_head: number of attention heads.
        drop_prob: dropout probability used for every dropout in the layer.
    """

    def __init__(self, d_model, d_ff, n_head, drop_prob):
        super().__init__()
        self.attention1 = MultiHeadAttention(d_model, n_head)
        self.drop1 = nn.Dropout(drop_prob)
        self.norm1 = LayerNorm(d_model)

        self.cross_attention = MultiHeadAttention(d_model, n_head)
        self.drop2 = nn.Dropout(drop_prob)
        self.norm2 = LayerNorm(d_model)

        # Forward drop_prob so the FFN's internal dropout honors the configured
        # rate instead of silently using PositionwiseFeedForward's 0.1 default.
        self.ffn = PositionwiseFeedForward(d_model, d_ff, drop_prob)
        self.drop3 = nn.Dropout(drop_prob)
        self.norm3 = LayerNorm(d_model)

    def forward(self, dec, enc, t_mask, s_mask):
        """Apply the layer.

        Args:
            dec: decoder hidden states.
            enc: encoder output used as keys/values for cross-attention.
            t_mask: mask for decoder self-attention.
            s_mask: mask for decoder-to-encoder cross-attention.
        """
        _dec = dec
        dec = self.attention1(dec, dec, dec, t_mask)
        dec = self.drop1(dec)
        dec = self.norm1(dec + _dec)

        _dec = dec
        dec = self.cross_attention(dec, enc, enc, s_mask)
        dec = self.drop2(dec)
        dec = self.norm2(dec + _dec)

        _dec = dec
        dec = self.ffn(dec)
        dec = self.drop3(dec)
        dec = self.norm3(dec + _dec)
        return dec


class Encoder(nn.Module):
    """Embedding layer followed by a stack of ``n_layer`` encoder layers."""

    def __init__(
        self, enc_voc_size, max_len, d_model, d_ff, n_head, n_layer, drop_prob, device
    ):
        super().__init__()
        self.embedding = TransformerEmbedding(
            enc_voc_size, d_model, max_len, drop_prob, device
        )
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, d_ff, n_head, drop_prob) for _ in range(n_layer)
        )

    def forward(self, x, s_mask):
        hidden = self.embedding(x)
        for block in self.layers:
            # Every layer sees the same source mask.
            hidden = block(hidden, s_mask)
        return hidden


class Decoder(nn.Module):
    """Embedding, a stack of ``n_layer`` decoder layers, then a vocabulary projection.

    ``forward`` returns unnormalized scores over ``dec_voc_size`` (no softmax
    is applied here).
    """

    def __init__(
        self, dec_voc_size, max_len, d_model, d_ff, n_head, n_layer, drop_prob, device
    ):
        super().__init__()
        self.embedding = TransformerEmbedding(
            dec_voc_size, d_model, max_len, drop_prob, device
        )
        self.layers = nn.ModuleList(
            DecoderLayer(d_model, d_ff, n_head, drop_prob) for _ in range(n_layer)
        )
        self.fc = nn.Linear(d_model, dec_voc_size)

    def forward(self, dec, enc, t_mask, s_mask):
        hidden = self.embedding(dec)
        for block in self.layers:
            hidden = block(hidden, enc, t_mask, s_mask)
        return self.fc(hidden)


class Transformer(nn.Module):
    """Encoder-decoder Transformer producing decoder output scores per target position.

    Args:
        src_pad_idx: token id treated as padding in ``src`` when building masks.
        trg_pad_idx: token id treated as padding in ``trg`` when building masks.
        enc_voc_size: source vocabulary size.
        dec_voc_size: target vocabulary size (width of the output scores).
        max_len: number of positions the positional tables cover.
        d_model, d_ff, n_head, n_layer, drop_prob: hyper-parameters shared by the
            encoder and decoder stacks.
        device: forwarded to the encoder/decoder for their positional tables;
            also kept as ``self.device`` for backward compatibility.

    NOTE(review): the pad ids only affect the attention masks built here; the
    token embeddings are constructed with ``padding_idx=1`` independently of them.
    """

    def __init__(
        self,
        src_pad_idx,
        trg_pad_idx,
        enc_voc_size,
        dec_voc_size,
        max_len,
        d_model,
        d_ff,
        n_head,
        n_layer,
        drop_prob,
        device,
    ):
        super().__init__()
        self.encoder = Encoder(
            enc_voc_size, max_len, d_model, d_ff, n_head, n_layer, drop_prob, device
        )
        self.decoder = Decoder(
            dec_voc_size, max_len, d_model, d_ff, n_head, n_layer, drop_prob, device
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        """Boolean mask of shape (batch, 1, q_len, k_len).

        Entry ``[b, 0, i, j]`` is True iff query token ``i`` and key token ``j``
        of batch item ``b`` are both non-padding. The singleton dimension
        broadcasts over attention heads. Assumes ``q`` and ``k`` are
        (batch, seq_len) token-id tensors.
        """
        q_valid = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)  # (batch, 1, q_len, 1)
        k_valid = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, k_len)
        # Broadcasting the AND yields the full mask without explicit .repeat() copies.
        return q_valid & k_valid

    def make_causal_mask(self, q, k):
        """Lower-triangular boolean mask (q_len, k_len): query ``i`` sees keys ``0..i``."""
        q_len, k_len = q.shape[1], k.shape[1]
        # Build directly as bool on the inputs' device. This avoids the legacy
        # torch.BoolTensor type and a CPU allocation + device copy on every call.
        return torch.ones(q_len, k_len, dtype=torch.bool, device=q.device).tril()

    def forward(self, src, trg):
        """Run encoder and decoder.

        Args:
            src: source token ids, assumed (batch, src_len).
            trg: target token ids, assumed (batch, trg_len).

        Returns:
            Decoder output of shape (batch, trg_len, dec_voc_size).
        """
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        # Decoder self-attention: ignore padding AND future positions.
        trg_mask = self.make_pad_mask(
            trg, trg, self.trg_pad_idx, self.trg_pad_idx
        ) & self.make_causal_mask(trg, trg)
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)
        enc = self.encoder(src, src_mask)
        return self.decoder(trg, enc, trg_mask, src_trg_mask)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"运行设备: {device}")

    # Hyper-parameters, passed to the model as keyword arguments.
    config = dict(
        src_pad_idx=1,
        trg_pad_idx=1,
        enc_voc_size=37000,
        dec_voc_size=37000,
        max_len=512,
        d_model=512,
        d_ff=2048,
        n_head=8,
        n_layer=6,
        drop_prob=0.1,
        device=device,
    )

    # Build the model and move it to the selected device.
    model = Transformer(**config).to(device)

    # Total parameter count.
    total_params = sum(param.numel() for param in model.parameters())
    print(f"标准Transformer总参数量: {total_params:,}")

    # Dummy inputs: random token ids.
    batch_size, src_len, trg_len = 4, 64, 64
    src = torch.randint(0, config["enc_voc_size"], (batch_size, src_len)).to(device)
    trg = torch.randint(0, config["dec_voc_size"], (batch_size, trg_len)).to(device)

    # Forward pass in eval mode without autograd.
    model.eval()
    with torch.no_grad():
        out = model(src, trg)

    # Shape checks.
    print(f"src shape: {src.shape}")
    print(f"trg shape: {trg.shape}")
    print(f"output shape: {out.shape}")
    assert out.shape == (batch_size, trg_len, config["dec_voc_size"])
    print("✅ 标准参数 Transformer 前向传播测试通过")