profileName: youpingfang postId: 218 postType: post categories:

- 6

简介

课代表: - 01:07 推理阶段的显存之权重 - 04:26 推理阶段的显存之 kv cache - 06:43 训练阶段的显存之静态显存 - 08:52 训练阶段的显存之动态显存 - 11:43 强化学习的显存 - 13:24 LoRA 的显存 - 15:19 MoE 模型的显存


核心要点总结

一、基础概念

数据单位

单位 换算
1 字节 (Byte) 8 位 (bit)
1 KB 1024 字节
1 MB 1024 KB ≈ 100 万字节
1 GB 1024 MB ≈ 10 亿字节
1 B (参数量) 10 亿参数

精度与字节数

精度 位数 占字节
FP 32 32 4
BF 16 / FP 16 16 2(主流)
INT 8 8 1
INT 4 4 0.5

二、推理阶段显存

推理显存 = 权重 + KV Cache

1. 权重显存

  • 公式:参数量 × 每参数字节数

  • 例:7 B 模型,BF 16(2 字节) → 7 × 10⁹ × 2 = 14 GB

    - 经验规律:参数量 × 2 ≈ 推理权重显存(GB)

2. KV Cache 显存

  • 公式:2 × L × H × 2 ^1
    • 第一个 2:K 和 V 两部分
    • L:层数(如 7 B 模型约 32 层)
    • H:隐藏层维度(如 7 B 模型约 4096)
    • 最后一个 2:BF 16 占 2 字节
  • 单 token 所需:2 × 32 × 4096 × 2 = 0.5 MB
  • 实际场景:4096 个 token 时 → 0.5 MB × 4096 ≈ 2 GB;batch size[^2]=2 时 → 4 GB

3. 推理显存估算(7 B 模型)

  • 权重 14 GB + KV Cache 约 2-4 GB ≈ 16-18 GB
  • 单张 RTX 4090(24 GB)可轻松容纳

三、训练阶段显存

训练显存 = 静态显存 + 动态显存

1. 静态显存(必须保留)

组件 精度 大小
权重 (Weight) BF 16 14 GB
梯度 (Gradient) BF 16 14 GB
优化器状态 (Adam) - 权重备份 FP 32 28 GB
优化器状态 - 一阶动量 (m) FP 32 28 GB
优化器状态 - 二阶动量 (v) FP 32 28 GB
合计 112 GB

经验公式:静态显存 ≈ 参数量 × 16(7 B × 16 = 112 GB)

2. 动态显存(激活值)

  • 4096 token 输入时约需 40 GB
  • 推导较复杂,直接记量级即可

3. 训练总显存估算

  • 112 GB + 40 GB = 152 GB
  • A 100 单卡 80 GB 无法容纳
  • 至少需要 2 张 A 100(160 GB)

四、优化手段

梯度检查点 (Gradient Checkpointing)

  • 原理:训练时不保存所有层的激活值,反向传播时重新计算
  • 效果:将激活值显存从 40 GB 降至约 5-10 GB
  • 代价:训练时间增加约 20-30%
  • 适用:显存紧张但可接受训练时间延长的场景

五、三种特殊场景

1. 强化学习(PPO)

PPO 算法包含 4 个模型:

模型 训练/推理 静态显存
Actor 模型 需要训练 112 GB
Critic 模型 需要训练 112 GB
Reference 模型 仅推理 14 GB
Reward 模型 仅推理 14 GB
合计 252 GB

另有训练过程中的 KV Cache 显存占用,PPO 是显存炸弹中的炸弹

2. LoRA 微调

  • 核心思想:冻结预训练权重,只更新少量低秩适配矩阵
  • 参数量级:仅为原模型的 0.01% ~ 0.1%(7 B 模型约 7 M 参数)
  • 训练显存:即使 7 B 模型 + BF 16 + LoRA,训练总显存可降至 20-40 GB
  • 结论:单张 A 100(80 GB)足矣

3. MoE 模型(如 Mixtral 8×7 B)

  • 总参数量 32 B,但每次只激活 3 B(每个 token 走部分专家)
  • 常见误区:认为显存也按激活量计算(错误)
  • 正确理解:32 B 参数必须全部加载到显存(随时待命)
  • 训练时更耗显存:需要额外的负载均衡损失防止专家坍缩

⚠️ 错误与不精确之处说明

  1. 视频中算错总显存:原文 112 + 40 = 156 GB,正确应为 112 + 40 = 152 GB。不过这不影响整体结论(2 张 A 100 足够,但很紧张)。
  2. LoRA 参数量比例:视频称 1/7000(约万分之 1.4),实际 LoRA 一般在 0.1%~1%(千分之一到百分之一)量级。两者差了一个数量级,LoRA 的实际可训练参数量远大于视频所述。
  3. 推理阶段 KV Cache 计算简略:视频假设 batch size 为 2 时 KV Cache 线性翻倍(2 → 4 GB),但实际中还受分组查询注意力(GQA) 等因素影响,KV Cache 增长略小于线性。这一细节在高阶面试中可能被追问。
  4. 激活值计算过于简化:视频直接给出 40 GB 而未提供推导,实际激活值与 batch size、序列长度、模型宽度、深度均密切相关,在不同配置下差异巨大。

面试题与参考答案

Q 1:用 FP 16/BF 16 推理一个 7 B 模型需要多少显存?

参考答案:约 14-18 GB。其中权重占用 7 × 10⁹ × 2 ≈ 14 GB;KV Cache 按 4096 tokens 计算约 2 GB,batch size 增大或序列变长会相应增加。单张 RTX 4090(24 GB)或 RTX 3090(24 GB)均可容纳。

Q 2:全量微调一个 7 B 模型需要多少显存?单张 A 100(80 GB)够吗?

参考答案:不够。静态显存(权重+梯度+优化器)约 112 GB(经验公式:参数量 × 16),动态显存(激活值)约 40 GB,合计约 152 GB。单张 A 100(80 GB)无法容纳,至少需要 2 张 A 100(合计 160 GB)才能完整进行全量微调。如果使用梯度检查点(Gradient Checkpointing),可将激活值显存压缩至 5-10 GB,此时单卡勉强可用但训练速度大幅下降。

Q 3:LoRA 为什么能大幅降低训练显存?

参考答案:LoRA 冻结预训练权重(不需要计算梯度、优化器状态),只更新少量低秩矩阵。首先,原模型的权重仍占用 14 GB 显存(推理时也需要),但不再需要梯度(+14 GB)和优化器状态(+84 GB)。其次,可训练参数只有约 0.1%,对应的梯度、优化器状态仅需数十 MB。因此,7 B 模型 + LoRA 的总训练显存从 152 GB 降至约 20-40 GB,单张 A 100 轻松容纳。

Q 4:MoE 模型的总参数量是 32 B,但每次只激活 3 B,显存是不是只要 3 B 对应的量?

参考答案不是。这是一个常见误区。MoE 虽然每次计算只激活部分专家,但所有专家参数都必须加载到显存中待命,因为不同的 token 会走不同的专家路径。因此,32 B 模型需要的显存与相同总参数量的 Dense 模型一致(约 64 GB 权重 + KV Cache)。MoE 节省的是计算量(FLOPs),而非显存占用。

Q 5:KV Cache 为什么会成为强化学习(PPO)的显存炸弹?

参考答案:PPO 训练中,Actor 模型需要反复进行推理(采样)来生成训练数据。每次推理都会产生 KV Cache,且 PPO 中通常包含 4 个模型(Actor、Critic、Reference、Reward),其中 Actor 和 Reference 都需要推理,产生的 KV Cache 数量是常规推理的 2-3 倍。此外,PPO 需要同时保存大量采样数据的 KV Cache 用于后续训练更新,进一步加剧显存压力。这是 PPO 训练成本极高的原因之一。

Q 6:梯度检查点(Gradient Checkpointing)的原理是什么?有无代价?

参考答案:梯度检查点通过空间换时间来降低显存。训练时只保存部分层的激活值,不在前向过程中保存所有层的激活结果,而是在反向传播过程中重新计算缺失的激活值。代价是训练时间增加约 20-30%(额外的前向计算消耗)。适用于显存紧张、对训练时间不敏感的场景,如单卡微调大模型。

## KV Cache 公式解析:2 × L × H × 2

### 公式含义拆解

| 因子 | 含义 | 7 B 模型取值 |
|------|------|-------------|
| **第一个 2** | K(Key)和 V(Value)**两个**矩阵 | 固定 |
| **L** | 层数(Layers) | 32 层 |
| **H** | 隐藏层维度(Hidden dimension) | 4096 |
| **最后一个 2** | BF 16 精度,每个数占 **2 字节** | 固定 |

### 直观理解

```
┌─────────────────────────────────────────────┐
│              单个 Token 的 KV Cache           │
├─────────────────────────────────────────────┤
│  L 层 × (K向量 4096维 + V向量 4096维)         │
│  = 32 × (4096 + 4096) 个数字                 │
│  = 32 × 8192 个数字                          │
│  × 2 字节(BF16)                            │
│  = 524,288 字节 ≈ 0.5 MB                     │
└─────────────────────────────────────────────┘
```

### 举例:一批 4096 tokens 的 KV Cache

```
0.5 MB × 4096 ≈ 2 GB
```

- 每个 token 独立缓存一份 K 和 V
- 4096 个 token → 显存线性增长

### 一句话总结

> **公式 = 2(K+V)× 层数 × 维度 × 2 字节**,算的是一个 token 在所有层、所有 K/V 向量上占用的显存。

### 为什么是 H × H?

KV Cache 本质上是 attention 中的 **Q、K、V** 矩阵运算:

```
Attention = softmax(Q × K^T / √d) × V
```

- 每个 token 的 K 和 V 向量,维度都是 **H = 4096**
- 整个序列的 K 矩阵 = `seq_len × H`,V 矩阵 = `seq_len × H`
- 每个 token 单独缓存时占:`2 × H` 个数字(K + V 各一个向量)

因此单 token KV Cache = `2 × H × 2 字节` = `2 × 4096 × 2`,再乘以层数 L 得到 `2 × L × H × 2`。

[^2]: ## Batch Size(批次大小)是什么

**Batch Size** 指模型**一次性同时处理的样本数量**。

### 以 KV Cache 的例子理解

你的摘录中:

```
单 token 所需:2 × 32 × 4096 × 2 ≈ 0.5 MB
实际场景:4096 个 token 时 → 0.5 MB × 4096 ≈ 2 GB
batch size=2 时 → 4 GB
```

**单独 tokens**:一个样本由 4096 个 token 组成,KV Cache 占 2 GB。

**batch size=2**:同时处理 **2 个这样的样本**(每个样本 4096 tokens),KV Cache 需要为 2 个独立的对话各自缓存,因此翻倍为 4 GB。

### 类比理解

- **batch size=1**:一次只看一个病人 → 只读一份病历
- **batch size=4**:一次看四个病人 → 同时摊开四份病历

### 在训练和推理中的影响

| 场景 | batch size 越大 → |
|------|-------------------|
| 训练 | 梯度更新更稳定,但显存占用线性增长 |
| 推理 | 吞吐量提高,但 KV Cache 增长,延迟可能增加 |

### 典型值参考

| 场景 | 常见 batch size |
|------|----------------|
| 单卡推理 | 1~8 |
| 单卡训练(7 B) | 1~4(受显存限制) |
| 多卡训练(70 B) | 每卡 1~4,总 batch 可达 64~256 |

> **一句话**:batch size = 一次送进 GPU 的样本数,越大越“吃”显存,但训练梯度更稳、推理吞吐更高。