这个仓库是 DeepGEMM,一个专注于高效 FP8(8位浮点数)通用矩阵乘法(GEMM)的库,由 DeepSeek 团队开发。它支持细粒度的缩放(fine-grained scaling)和混合专家(MoE)模型中的分组 GEMM 运算。
主要特点
-
FP8 GEMM 支持
- 专为 NVIDIA Hopper 架构(如 H100 GPU)优化,利用 Tensor Core 进行高性能计算。
- 由于 FP8 运算精度较低,采用 CUDA 核心进行两级累加(promotion)来保证计算精度。
-
分组 GEMM(Grouped GEMM)
- 支持 连续布局(contiguous layout) 和 掩码布局(masked layout),适用于 MoE 模型训练和推理。
- 连续布局适用于训练或推理预填充阶段,而掩码布局适用于解码阶段(如 CUDA Graph 场景)。
-
权重梯度计算(Weight Gradient Kernels)
- 支持 密集模型(Dense) 和 MoE 模型 的反向传播计算。
-
轻量级 JIT(Just-In-Time)编译
- 无需安装时编译,所有 CUDA 核心在运行时动态编译,减少部署复杂度。
- 支持 NVCC 和 NVRTC(NVIDIA Runtime Compiler),后者编译速度更快(最高 10 倍)。
-
高性能优化
- 采用 Hopper TMA(Tensor Memory Accelerator) 进行异步数据加载和存储。
- 持久化 Warp 专业化(Persistent Warp-Specialization),优化数据移动和计算重叠。
- FFMA SASS 指令交错(Interleaving),提升 FP8 运算效率。
- 非对齐块大小(Unaligned Block Sizes),提高 SM(流式多处理器)利用率。
-
简洁的代码设计
- 仅有一个核心 GEMM 核函数,便于理解和优化。
应用场景
- 大模型训练与推理(如 Transformer、MoE 架构)。
- 高性能计算(HPC),需要低精度(FP8)矩阵运算的场景。
- 需要动态调整计算形状的任务(如变长序列处理)。
快速开始
环境要求
-
GPU: NVIDIA Hopper 架构(如 H100,需
sm_90a
支持)。 - CUDA: 12.3+(推荐 12.8+ 以获得最佳性能)。
- Python: 3.8+。
- PyTorch: 2.1+。
- CUTLASS: 3.6+(通过 Git 子模块引入)。
安装
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git
python setup.py develop # 开发模式
python setup.py install # 安装
使用示例
import deep_gemm
# 普通 FP8 GEMM
lhs = (torch.randn(4096, 7168, dtype=torch.float8_e4m3fn), torch.randn(4096, 56, dtype=torch.float32)) # [M, K], [M, K//128]
rhs = (torch.randn(2112, 7168, dtype=torch.float8_e4m3fn), torch.randn(17, 56, dtype=torch.float32)) # [N, K], [N//128, K//128]
out = torch.empty(4096, 2112, dtype=torch.bfloat16, device="cuda")
deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
# MoE 分组 GEMM(连续布局)
m_indices = torch.randint(0, 4, (8192,), device="cuda") # 分组索引
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
性能表现
- 在 H800 GPU 上达到 1550 TFLOPS(FP8 算力)。
- 性能优于或持平专家调优的库(如 CUTLASS)。
未来计划
- 支持 BF16 运算。
- 更大的块大小(N 维度扩展至 256)。
- 优化功耗效率。
- 支持 CUDA PDL(Pattern Description Language)。
总结
DeepGEMM 是一个高效、易用的 FP8 GEMM 库,特别适合大模型和 MoE 架构的计算优化。它的 JIT 编译设计和细粒度优化使其在多种矩阵形状下都能保持高性能。
论文引用
如需引用,请参考仓库中的 BibTeX 格式。
在 大语言模型(LLMs) 的训练和推理过程中,DeepGEMM 可以替代以下关键矩阵乘法(GEMM)运算,显著提升计算效率,尤其是在 FP8 低精度计算 和 混合专家(MoE)模型 场景下:
1. 全连接层(Feed-Forward Network, FFN)
标准 FFN:
LLM 中的全连接层(如W_in @ X
和W_out @ (GELU(W_gate @ X))
)通常占计算量的 30%~50%。
替换方案:
使用gemm_fp8_fp8_bf16_nt
替代 FP16/BF16 的torch.matmul
,利用 FP8 计算加速,同时通过两级累加保持精度。MoE 模型的专家层:
MoE 模型中,不同 token 被路由到不同专家(如Expert_i @ X
),计算是稀疏的。
替换方案:
使用分组 GEMM(m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
或m_grouped_gemm_fp8_fp8_bf16_nt_masked
),将多个专家的计算合并为一次批处理,减少启动开销。
2. 注意力机制(Attention)
QKV 投影(Q/K/V Matrices):
计算Q = X @ W_Q
,K = X @ W_K
,V = X @ W_V
时,可用gemm_fp8_fp8_bf16_nt
加速。
注意:Softmax 仍需 FP16/BF16 计算(FP8 精度不足)。注意力得分(Attention Scores):
S = Q @ K^T
理论上可用 FP8 GEMM,但需配合缩放因子(因 FP8 动态范围小)。DeepGEMM 的细粒度缩放支持可能适用。
3. 反向传播中的梯度计算
权重梯度(Weight Gradients):
在反向传播中,计算dW = X^T @ dY
(如 FFN 层的W.grad
)。
替换方案:
使用wgrad_gemm_fp8_fp8_fp32_nt
,直接以 FP8 输入计算 FP32 梯度,避免显式类型转换。MoE 的专家梯度聚合:
使用k_grouped_wgrad_gemm_fp8_fp8_fp32_nt
处理不同专家的梯度,避免逐专家计算。
4. 其他场景
-
嵌入层(Embedding)的投影:
如
logits = hidden_states @ embedding_matrix.T
,可用 FP8 GEMM 加速(需对齐维度)。 -
LoRA 适配器的低秩乘法:
LoRA 的
A @ B
低秩矩阵乘法,适合 FP8 计算。
适用条件
- 硬件要求:需 NVIDIA Hopper GPU(如 H100),支持 FP8 Tensor Core。
- 精度权衡:FP8 可能影响模型收敛性,建议在训练后期或推理阶段使用。
-
形状对齐:输入矩阵的
K
维度需对齐 128(DeepGEMM 的BLOCK_K=128
),否则需填充。
性能收益
- 速度:FP8 理论算力是 BF16 的 4 倍(实际加速 2~3 倍)。
- 显存:FP8 数据占用显存仅为 BF16 的 1/4,可处理更大 batch size。
- MoE 优化:分组 GEMM 减少核函数启动次数,提升吞吐量。
示例代码(替换 HuggingFace 模型中的线性层)
from transformers import AutoModel
import deep_gemm
model = AutoModel.from_pretrained("meta-llama/Llama-2-7b").cuda()
# 替换 FFN 层的矩阵乘法
def fp8_linear(x, weight, scales_x, scales_w):
# x: [seq_len, hidden_dim], weight: [hidden_dim, out_dim]
x_fp8 = (x.to(torch.float8_e4m3fn), scales_x) # 伪代码,需实际量化
w_fp8 = (weight.to(torch.float8_e4m3fn), scales_w)
out = torch.empty(x.shape[0], weight.shape[1], dtype=torch.bfloat16, device="cuda")
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, w_fp8, out)
return out
# 替换原模型的 forward 计算
model.layers[0].mlp.dense_h_to_4h = fp8_linear
注意事项
-
精度校准:FP8 需要动态缩放因子(DeepGEMM 已内置),建议在训练时统计
amax
。 -
核函数选择:小矩阵(如
M < 128
)可能无法充分利用 Tensor Core,需测试性能。
通过合理替换,DeepGEMM 可显著加速 LLM 的 训练迭代 和 推理延迟,尤其适合 MoE 或超大模型场景。
Top comments (0)