DEV Community

Vuk Rosić
Vuk Rosić

Posted on

DeepSeek DeepGEMM 中文讲解

这个仓库是 DeepGEMM,一个专注于高效 FP8(8位浮点数)通用矩阵乘法(GEMM)的库,由 DeepSeek 团队开发。它支持细粒度的缩放(fine-grained scaling)和混合专家(MoE)模型中的分组 GEMM 运算。

主要特点

  1. FP8 GEMM 支持

    • 专为 NVIDIA Hopper 架构(如 H100 GPU)优化,利用 Tensor Core 进行高性能计算。
    • 由于 FP8 运算精度较低,采用 CUDA 核心进行两级累加(promotion)来保证计算精度。
  2. 分组 GEMM(Grouped GEMM)

    • 支持 连续布局(contiguous layout)掩码布局(masked layout),适用于 MoE 模型训练和推理。
    • 连续布局适用于训练或推理预填充阶段,而掩码布局适用于解码阶段(如 CUDA Graph 场景)。
  3. 权重梯度计算(Weight Gradient Kernels)

    • 支持 密集模型(Dense)MoE 模型 的反向传播计算。
  4. 轻量级 JIT(Just-In-Time)编译

    • 无需安装时编译,所有 CUDA 核心在运行时动态编译,减少部署复杂度。
    • 支持 NVCC 和 NVRTC(NVIDIA Runtime Compiler),后者编译速度更快(最高 10 倍)。
  5. 高性能优化

    • 采用 Hopper TMA(Tensor Memory Accelerator) 进行异步数据加载和存储。
    • 持久化 Warp 专业化(Persistent Warp-Specialization),优化数据移动和计算重叠。
    • FFMA SASS 指令交错(Interleaving),提升 FP8 运算效率。
    • 非对齐块大小(Unaligned Block Sizes),提高 SM(流式多处理器)利用率。
  6. 简洁的代码设计

    • 仅有一个核心 GEMM 核函数,便于理解和优化。

应用场景

  • 大模型训练与推理(如 Transformer、MoE 架构)。
  • 高性能计算(HPC),需要低精度(FP8)矩阵运算的场景。
  • 需要动态调整计算形状的任务(如变长序列处理)。

快速开始

环境要求

  • GPU: NVIDIA Hopper 架构(如 H100,需 sm_90a 支持)。
  • CUDA: 12.3+(推荐 12.8+ 以获得最佳性能)。
  • Python: 3.8+。
  • PyTorch: 2.1+。
  • CUTLASS: 3.6+(通过 Git 子模块引入)。

安装

git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git
python setup.py develop  # 开发模式
python setup.py install  # 安装
Enter fullscreen mode Exit fullscreen mode

使用示例

import deep_gemm

# 普通 FP8 GEMM
lhs = (torch.randn(4096, 7168, dtype=torch.float8_e4m3fn), torch.randn(4096, 56, dtype=torch.float32))  # [M, K], [M, K//128]
rhs = (torch.randn(2112, 7168, dtype=torch.float8_e4m3fn), torch.randn(17, 56, dtype=torch.float32))    # [N, K], [N//128, K//128]
out = torch.empty(4096, 2112, dtype=torch.bfloat16, device="cuda")
deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)

# MoE 分组 GEMM(连续布局)
m_indices = torch.randint(0, 4, (8192,), device="cuda")  # 分组索引
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
Enter fullscreen mode Exit fullscreen mode

性能表现

  • 在 H800 GPU 上达到 1550 TFLOPS(FP8 算力)。
  • 性能优于或持平专家调优的库(如 CUTLASS)。

未来计划

  • 支持 BF16 运算。
  • 更大的块大小(N 维度扩展至 256)。
  • 优化功耗效率。
  • 支持 CUDA PDL(Pattern Description Language)。

总结

DeepGEMM 是一个高效、易用的 FP8 GEMM 库,特别适合大模型和 MoE 架构的计算优化。它的 JIT 编译设计和细粒度优化使其在多种矩阵形状下都能保持高性能。

论文引用

如需引用,请参考仓库中的 BibTeX 格式。


大语言模型(LLMs) 的训练和推理过程中,DeepGEMM 可以替代以下关键矩阵乘法(GEMM)运算,显著提升计算效率,尤其是在 FP8 低精度计算混合专家(MoE)模型 场景下:


1. 全连接层(Feed-Forward Network, FFN)

  • 标准 FFN

    LLM 中的全连接层(如 W_in @ XW_out @ (GELU(W_gate @ X)))通常占计算量的 30%~50%。

    替换方案

    使用 gemm_fp8_fp8_bf16_nt 替代 FP16/BF16 的 torch.matmul,利用 FP8 计算加速,同时通过两级累加保持精度。

  • MoE 模型的专家层

    MoE 模型中,不同 token 被路由到不同专家(如 Expert_i @ X),计算是稀疏的。

    替换方案

    使用分组 GEMM(m_grouped_gemm_fp8_fp8_bf16_nt_contiguousm_grouped_gemm_fp8_fp8_bf16_nt_masked),将多个专家的计算合并为一次批处理,减少启动开销。


2. 注意力机制(Attention)

  • QKV 投影(Q/K/V Matrices)

    计算 Q = X @ W_Q, K = X @ W_K, V = X @ W_V 时,可用 gemm_fp8_fp8_bf16_nt 加速。

    注意:Softmax 仍需 FP16/BF16 计算(FP8 精度不足)。

  • 注意力得分(Attention Scores)

    S = Q @ K^T 理论上可用 FP8 GEMM,但需配合缩放因子(因 FP8 动态范围小)。DeepGEMM 的细粒度缩放支持可能适用。


3. 反向传播中的梯度计算

  • 权重梯度(Weight Gradients)

    在反向传播中,计算 dW = X^T @ dY(如 FFN 层的 W.grad)。

    替换方案

    使用 wgrad_gemm_fp8_fp8_fp32_nt,直接以 FP8 输入计算 FP32 梯度,避免显式类型转换。

  • MoE 的专家梯度聚合

    使用 k_grouped_wgrad_gemm_fp8_fp8_fp32_nt 处理不同专家的梯度,避免逐专家计算。


4. 其他场景

  • 嵌入层(Embedding)的投影: 如 logits = hidden_states @ embedding_matrix.T,可用 FP8 GEMM 加速(需对齐维度)。
  • LoRA 适配器的低秩乘法: LoRA 的 A @ B 低秩矩阵乘法,适合 FP8 计算。

适用条件

  • 硬件要求:需 NVIDIA Hopper GPU(如 H100),支持 FP8 Tensor Core。
  • 精度权衡:FP8 可能影响模型收敛性,建议在训练后期或推理阶段使用。
  • 形状对齐:输入矩阵的 K 维度需对齐 128(DeepGEMM 的 BLOCK_K=128),否则需填充。

性能收益

  • 速度:FP8 理论算力是 BF16 的 4 倍(实际加速 2~3 倍)。
  • 显存:FP8 数据占用显存仅为 BF16 的 1/4,可处理更大 batch size。
  • MoE 优化:分组 GEMM 减少核函数启动次数,提升吞吐量。

示例代码(替换 HuggingFace 模型中的线性层)

from transformers import AutoModel
import deep_gemm

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b").cuda()

# 替换 FFN 层的矩阵乘法
def fp8_linear(x, weight, scales_x, scales_w):
    # x: [seq_len, hidden_dim], weight: [hidden_dim, out_dim]
    x_fp8 = (x.to(torch.float8_e4m3fn), scales_x)  # 伪代码,需实际量化
    w_fp8 = (weight.to(torch.float8_e4m3fn), scales_w)
    out = torch.empty(x.shape[0], weight.shape[1], dtype=torch.bfloat16, device="cuda")
    deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, w_fp8, out)
    return out

# 替换原模型的 forward 计算
model.layers[0].mlp.dense_h_to_4h = fp8_linear
Enter fullscreen mode Exit fullscreen mode

注意事项

  • 精度校准:FP8 需要动态缩放因子(DeepGEMM 已内置),建议在训练时统计 amax
  • 核函数选择:小矩阵(如 M < 128)可能无法充分利用 Tensor Core,需测试性能。

通过合理替换,DeepGEMM 可显著加速 LLM 的 训练迭代推理延迟,尤其适合 MoE 或超大模型场景。

Top comments (0)