profileName: youpingfang postId: 218 postType: post categories:
- 6
简介
课代表: - 01:07 推理阶段的显存之权重 - 04:26 推理阶段的显存之 kv cache - 06:43 训练阶段的显存之静态显存 - 08:52 训练阶段的显存之动态显存 - 11:43 强化学习的显存 - 13:24 LoRA 的显存 - 15:19 MoE 模型的显存
核心要点总结
一、基础概念
数据单位
| 单位 | 换算 |
|---|---|
| 1 字节 (Byte) | 8 位 (bit) |
| 1 KB | 1024 字节 |
| 1 MB | 1024 KB ≈ 100 万字节 |
| 1 GB | 1024 MB ≈ 10 亿字节 |
| 1 B (参数量) | 10 亿参数 |
精度与字节数
| 精度 | 位数 | 占字节 |
|---|---|---|
| FP 32 | 32 | 4 |
| BF 16 / FP 16 | 16 | 2(主流) |
| INT 8 | 8 | 1 |
| INT 4 | 4 | 0.5 |
二、推理阶段显存
推理显存 = 权重 + KV Cache
1. 权重显存
公式:参数量 × 每参数字节数
例:7 B 模型,BF 16(2 字节) → 7 × 10⁹ × 2 = 14 GB
- 经验规律:参数量 × 2 ≈ 推理权重显存(GB)
2. KV Cache 显存
- 公式:2 × L × H × 2 ^1
- 第一个 2:K 和 V 两部分
- L:层数(如 7 B 模型约 32 层)
- H:隐藏层维度(如 7 B 模型约 4096)
- 最后一个 2:BF 16 占 2 字节
- 单 token 所需:2 × 32 × 4096 × 2 = 0.5 MB
- 实际场景:4096 个 token 时 → 0.5 MB × 4096 ≈ 2 GB;batch size[^2]=2 时 → 4 GB
3. 推理显存估算(7 B 模型)
- 权重 14 GB + KV Cache 约 2-4 GB ≈ 16-18 GB
- 单张 RTX 4090(24 GB)可轻松容纳
三、训练阶段显存
训练显存 = 静态显存 + 动态显存
1. 静态显存(必须保留)
| 组件 | 精度 | 大小 |
|---|---|---|
| 权重 (Weight) | BF 16 | 14 GB |
| 梯度 (Gradient) | BF 16 | 14 GB |
| 优化器状态 (Adam) - 权重备份 | FP 32 | 28 GB |
| 优化器状态 - 一阶动量 (m) | FP 32 | 28 GB |
| 优化器状态 - 二阶动量 (v) | FP 32 | 28 GB |
| 合计 | 112 GB |
经验公式:静态显存 ≈ 参数量 × 16(7 B × 16 = 112 GB)
2. 动态显存(激活值)
- 4096 token 输入时约需 40 GB
- 推导较复杂,直接记量级即可
3. 训练总显存估算
- 112 GB + 40 GB = 152 GB
- A 100 单卡 80 GB 无法容纳
- 至少需要 2 张 A 100(160 GB)
四、优化手段
梯度检查点 (Gradient Checkpointing)
- 原理:训练时不保存所有层的激活值,反向传播时重新计算
- 效果:将激活值显存从 40 GB 降至约 5-10 GB
- 代价:训练时间增加约 20-30%
- 适用:显存紧张但可接受训练时间延长的场景
五、三种特殊场景
1. 强化学习(PPO)
PPO 算法包含 4 个模型:
| 模型 | 训练/推理 | 静态显存 |
|---|---|---|
| Actor 模型 | 需要训练 | 112 GB |
| Critic 模型 | 需要训练 | 112 GB |
| Reference 模型 | 仅推理 | 14 GB |
| Reward 模型 | 仅推理 | 14 GB |
| 合计 | 252 GB |
另有训练过程中的 KV Cache 显存占用,PPO 是显存炸弹中的炸弹。
2. LoRA 微调
- 核心思想:冻结预训练权重,只更新少量低秩适配矩阵
- 参数量级:仅为原模型的 0.01% ~ 0.1%(7 B 模型约 7 M 参数)
- 训练显存:即使 7 B 模型 + BF 16 + LoRA,训练总显存可降至 20-40 GB
- 结论:单张 A 100(80 GB)足矣
3. MoE 模型(如 Mixtral 8×7 B)
- 总参数量 32 B,但每次只激活 3 B(每个 token 走部分专家)
- 常见误区:认为显存也按激活量计算(错误)
- 正确理解:32 B 参数必须全部加载到显存(随时待命)
- 训练时更耗显存:需要额外的负载均衡损失防止专家坍缩
⚠️ 错误与不精确之处说明
- 视频中算错总显存:原文 112 + 40 = 156 GB,正确应为 112 + 40 = 152 GB。不过这不影响整体结论(2 张 A 100 足够,但很紧张)。
- LoRA 参数量比例:视频称 1/7000(约万分之 1.4),实际 LoRA 一般在 0.1%~1%(千分之一到百分之一)量级。两者差了一个数量级,LoRA 的实际可训练参数量远大于视频所述。
- 推理阶段 KV Cache 计算简略:视频假设 batch size 为 2 时 KV Cache 线性翻倍(2 → 4 GB),但实际中还受分组查询注意力(GQA) 等因素影响,KV Cache 增长略小于线性。这一细节在高阶面试中可能被追问。
- 激活值计算过于简化:视频直接给出 40 GB 而未提供推导,实际激活值与 batch size、序列长度、模型宽度、深度均密切相关,在不同配置下差异巨大。
面试题与参考答案
Q 1:用 FP 16/BF 16 推理一个 7 B 模型需要多少显存?
参考答案:约 14-18 GB。其中权重占用 7 × 10⁹ × 2 ≈ 14 GB;KV Cache 按 4096 tokens 计算约 2 GB,batch size 增大或序列变长会相应增加。单张 RTX 4090(24 GB)或 RTX 3090(24 GB)均可容纳。
Q 2:全量微调一个 7 B 模型需要多少显存?单张 A 100(80 GB)够吗?
参考答案:不够。静态显存(权重+梯度+优化器)约 112 GB(经验公式:参数量 × 16),动态显存(激活值)约 40 GB,合计约 152 GB。单张 A 100(80 GB)无法容纳,至少需要 2 张 A 100(合计 160 GB)才能完整进行全量微调。如果使用梯度检查点(Gradient Checkpointing),可将激活值显存压缩至 5-10 GB,此时单卡勉强可用但训练速度大幅下降。
Q 3:LoRA 为什么能大幅降低训练显存?
参考答案:LoRA 冻结预训练权重(不需要计算梯度、优化器状态),只更新少量低秩矩阵。首先,原模型的权重仍占用 14 GB 显存(推理时也需要),但不再需要梯度(+14 GB)和优化器状态(+84 GB)。其次,可训练参数只有约 0.1%,对应的梯度、优化器状态仅需数十 MB。因此,7 B 模型 + LoRA 的总训练显存从 152 GB 降至约 20-40 GB,单张 A 100 轻松容纳。
Q 4:MoE 模型的总参数量是 32 B,但每次只激活 3 B,显存是不是只要 3 B 对应的量?
参考答案:不是。这是一个常见误区。MoE 虽然每次计算只激活部分专家,但所有专家参数都必须加载到显存中待命,因为不同的 token 会走不同的专家路径。因此,32 B 模型需要的显存与相同总参数量的 Dense 模型一致(约 64 GB 权重 + KV Cache)。MoE 节省的是计算量(FLOPs),而非显存占用。
Q 5:KV Cache 为什么会成为强化学习(PPO)的显存炸弹?
参考答案:PPO 训练中,Actor 模型需要反复进行推理(采样)来生成训练数据。每次推理都会产生 KV Cache,且 PPO 中通常包含 4 个模型(Actor、Critic、Reference、Reward),其中 Actor 和 Reference 都需要推理,产生的 KV Cache 数量是常规推理的 2-3 倍。此外,PPO 需要同时保存大量采样数据的 KV Cache 用于后续训练更新,进一步加剧显存压力。这是 PPO 训练成本极高的原因之一。
Q 6:梯度检查点(Gradient Checkpointing)的原理是什么?有无代价?
参考答案:梯度检查点通过空间换时间来降低显存。训练时只保存部分层的激活值,不在前向过程中保存所有层的激活结果,而是在反向传播过程中重新计算缺失的激活值。代价是训练时间增加约 20-30%(额外的前向计算消耗)。适用于显存紧张、对训练时间不敏感的场景,如单卡微调大模型。
## KV Cache 公式解析:2 × L × H × 2
### 公式含义拆解
| 因子 | 含义 | 7 B 模型取值 |
|------|------|-------------|
| **第一个 2** | K(Key)和 V(Value)**两个**矩阵 | 固定 |
| **L** | 层数(Layers) | 32 层 |
| **H** | 隐藏层维度(Hidden dimension) | 4096 |
| **最后一个 2** | BF 16 精度,每个数占 **2 字节** | 固定 |
### 直观理解
```
┌─────────────────────────────────────────────┐
│ 单个 Token 的 KV Cache │
├─────────────────────────────────────────────┤
│ L 层 × (K向量 4096维 + V向量 4096维) │
│ = 32 × (4096 + 4096) 个数字 │
│ = 32 × 8192 个数字 │
│ × 2 字节(BF16) │
│ = 524,288 字节 ≈ 0.5 MB │
└─────────────────────────────────────────────┘
```
### 举例:一批 4096 tokens 的 KV Cache
```
0.5 MB × 4096 ≈ 2 GB
```
- 每个 token 独立缓存一份 K 和 V
- 4096 个 token → 显存线性增长
### 一句话总结
> **公式 = 2(K+V)× 层数 × 维度 × 2 字节**,算的是一个 token 在所有层、所有 K/V 向量上占用的显存。
### 为什么是 H × H?
KV Cache 本质上是 attention 中的 **Q、K、V** 矩阵运算:
```
Attention = softmax(Q × K^T / √d) × V
```
- 每个 token 的 K 和 V 向量,维度都是 **H = 4096**
- 整个序列的 K 矩阵 = `seq_len × H`,V 矩阵 = `seq_len × H`
- 每个 token 单独缓存时占:`2 × H` 个数字(K + V 各一个向量)
因此单 token KV Cache = `2 × H × 2 字节` = `2 × 4096 × 2`,再乘以层数 L 得到 `2 × L × H × 2`。
[^2]: ## Batch Size(批次大小)是什么
**Batch Size** 指模型**一次性同时处理的样本数量**。
### 以 KV Cache 的例子理解
你的摘录中:
```
单 token 所需:2 × 32 × 4096 × 2 ≈ 0.5 MB
实际场景:4096 个 token 时 → 0.5 MB × 4096 ≈ 2 GB
batch size=2 时 → 4 GB
```
**单独 tokens**:一个样本由 4096 个 token 组成,KV Cache 占 2 GB。
**batch size=2**:同时处理 **2 个这样的样本**(每个样本 4096 tokens),KV Cache 需要为 2 个独立的对话各自缓存,因此翻倍为 4 GB。
### 类比理解
- **batch size=1**:一次只看一个病人 → 只读一份病历
- **batch size=4**:一次看四个病人 → 同时摊开四份病历
### 在训练和推理中的影响
| 场景 | batch size 越大 → |
|------|-------------------|
| 训练 | 梯度更新更稳定,但显存占用线性增长 |
| 推理 | 吞吐量提高,但 KV Cache 增长,延迟可能增加 |
### 典型值参考
| 场景 | 常见 batch size |
|------|----------------|
| 单卡推理 | 1~8 |
| 单卡训练(7 B) | 1~4(受显存限制) |
| 多卡训练(70 B) | 每卡 1~4,总 batch 可达 64~256 |
> **一句话**:batch size = 一次送进 GPU 的样本数,越大越“吃”显存,但训练梯度更稳、推理吞吐更高。