DEV Community

Vuk Rosić


DeepGEMM Essentials: High-Performance FP8 Matrix Multiplication


Google Colab

Master these concepts and you'll be able to leverage cutting-edge FP8 acceleration on Hopper H100, H200 & H800 GPUs!

Part 1: Getting Started - Your First FP8 GEMM

import torch
import deep_gemm

# Create simple input matrices
m, n, k = 128, 256, 512
lhs = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
rhs = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
output = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)

print(f"LHS shape: {lhs.shape}")    # [128, 512]
print(f"RHS shape: {rhs.shape}")    # [256, 512] 
print(f"Output shape: {output.shape}")  # [128, 256]

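What happened: We created the input and output tensors for DeepGEMM's NT layout, which computes Output = LHS × RHS^T. Both inputs keep K (here 512) as their last, contiguous dimension. `output` is only allocated here; the GEMM call fills it later.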

Part 2: Understanding FP8 - Why It Matters

# .numel() returns the total number of elements in a tensor
small_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(f"Small tensor shape: {small_tensor.shape}")
print(f"Small tensor elements: {small_tensor.numel()}")  # 2 * 3 = 6

matrix_2d = torch.randn(50, 20)
print(f"2D matrix elements: {matrix_2d.numel()}")  # 50 * 20 = 1000

# Our actual tensors
print(f"LHS elements: {lhs.numel()}")  # 128 * 512 = 65,536
print(f"RHS elements: {rhs.numel()}")  # 256 * 512 = 131,072
# Regular BF16 GEMM (what you normally do)
reference = lhs @ rhs.t()  # Standard PyTorch GEMM

# Check memory usage
bf16_memory = lhs.numel() * 2 + rhs.numel() * 2  # 2 bytes per BF16
fp8_memory = lhs.numel() * 1 + rhs.numel() * 1   # 1 byte per FP8

print(f"BF16 memory: {bf16_memory / 1024**2:.1f} MB")
print(f"FP8 memory: {fp8_memory / 1024**2:.1f} MB")
print(f"Memory saved: {(1 - fp8_memory/bf16_memory)*100:.1f}%")

Key insight: FP8 uses half the memory of BF16, plus a small overhead for the FP32 scale factors. With proper scaling, the accuracy loss stays small.
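The 50% figure leaves out the FP32 scale factors that travel with the FP8 data. Below is a minimal pure-Python sketch that adds them back. It assumes one FP32 scale per 1x128 group for the LHS and one per 128x128 block for the RHS, which are the layouts used in Parts 3 and 4. `gemm_input_bytes` is a hypothetical helper.

```python
import math

def gemm_input_bytes(m: int, n: int, k: int) -> dict:
    """Bytes to hold LHS [m, k] and RHS [n, k] as BF16 vs. FP8 plus FP32 scales."""
    bf16 = (m * k + n * k) * 2                                # 2 bytes per BF16 element
    fp8 = (m * k + n * k) * 1                                 # 1 byte per FP8 element
    lhs_scales = m * math.ceil(k / 128) * 4                   # one FP32 scale per 1x128 group
    rhs_scales = math.ceil(n / 128) * math.ceil(k / 128) * 4  # one FP32 scale per 128x128 block
    fp8_total = fp8 + lhs_scales + rhs_scales
    return {
        "bf16": bf16,
        "fp8_with_scales": fp8_total,
        "saved_pct": 100.0 * (1 - fp8_total / bf16),
    }

usage = gemm_input_bytes(128, 256, 512)
print(usage)
```

For the 128x256x512 example, the scales add about 2 KB. The saving therefore comes out at roughly 49.5% rather than exactly 50%.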

Part 3: Converting to FP8 with Scaling

def cast_to_fp8_per_token(x: torch.Tensor):
    """Convert tensor to FP8 with per-token scaling: one scale per row per 128-column group"""
    assert x.dim() == 2
    m, n = x.shape

    # Pad columns to a multiple of 128 so each row splits into whole 128-element groups
    pad_size = (128 - (n % 128)) % 128
    if pad_size > 0:
        x = torch.nn.functional.pad(x, (0, pad_size), value=0)

    # Reshape for scaling calculation (.contiguous() so .view also works on transposed inputs)
    x_view = x.contiguous().view(m, -1, 128)  # [m, ceil(n/128), 128]

    # Find max absolute value per 128-element group
    x_amax = x_view.abs().float().amax(dim=2).clamp(1e-4)  # [m, ceil(n/128)]

    # Scale so each group's max maps to 448.0, the largest finite float8_e4m3fn value
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    scale_factors = x_amax / 448.0  # multiply FP8 values by these to dequantize

    # Drop the padding columns; .contiguous() gives the dense layout DeepGEMM expects
    return fp8_data.view(m, -1)[:, :n].contiguous(), scale_factors

# Convert our matrices
lhs_fp8, lhs_scales = cast_to_fp8_per_token(lhs)
print(f"Original: {lhs.dtype}, Converted: {lhs_fp8.dtype}")
print(f"Scale factors shape: {lhs_scales.shape}")

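Critical concept: float8_e4m3fn can only represent magnitudes up to 448 and has just 3 mantissa bits. Scaling each group so that its largest magnitude maps to 448 prevents overflow and uses the full FP8 range. The stored scale factors undo the scaling afterwards.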
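You can see what the 448.0 and the scaling buy you without a GPU. Below is a pure-Python model of the float8_e4m3fn format: 4 exponent bits with bias 7, 3 mantissa bits, no infinities, and a largest finite value of 448. It is a sketch for intuition, not PyTorch's conversion code. `round_to_e4m3` and `quantize_group` are hypothetical helpers, and rounding is assumed to be round-to-nearest-even.

```python
import bisect
import math

def e4m3fn_values():
    """All non-negative finite float8_e4m3fn values, in ascending order."""
    vals = []
    for exp in range(16):
        for mant in range(8):
            if exp == 15 and mant == 7:
                continue  # all-ones exponent and mantissa encodes NaN in E4M3FN
            if exp == 0:
                vals.append(mant * 2.0 ** -9)  # subnormals: (mant / 8) * 2^-6
            else:
                vals.append((1 + mant / 8) * 2.0 ** (exp - 7))  # exponent bias is 7
    return vals

E4M3 = e4m3fn_values()  # ascending; E4M3[-1] is 448.0

def round_to_e4m3(x: float) -> float:
    """Round to the nearest E4M3FN value (ties to even); clamps float round-off past 448."""
    mag = min(abs(x), E4M3[-1])
    i = bisect.bisect_left(E4M3, mag)
    if i == 0:
        return math.copysign(0.0, x)
    lo, hi = E4M3[i - 1], E4M3[i]
    if mag - lo < hi - mag:
        q = lo
    elif hi - mag < mag - lo:
        q = hi
    else:
        q = lo if (i - 1) % 2 == 0 else hi  # even index means even mantissa
    return math.copysign(q, x)

def quantize_group(group):
    """Per-group scaling as in cast_to_fp8_per_token: map the group's max |x| to 448."""
    amax = max(max(abs(v) for v in group), 1e-4)
    q = [round_to_e4m3(v * (448.0 / amax)) for v in group]
    return q, amax / 448.0  # FP8 values (as Python floats) and the dequantization scale

tiny = [1e-4, -2e-4, 5e-5, 3e-4]
print([round_to_e4m3(v) for v in tiny])  # unscaled: every value underflows to zero
q, scale = quantize_group(tiny)
print([v * scale for v in q])            # scaled: values survive the round trip
```

The unscaled values all fall below half of the smallest E4M3 subnormal (2^-9), so they round to zero. After scaling, the group's largest value lands on 448, and the others come back within a few percent.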

Part 4: Block-wise Scaling for RHS

def cast_to_fp8_per_block(x: torch.Tensor):
    """Convert tensor to FP8 with per-block scaling (one scale per 128x128 block)"""
    assert x.dim() == 2
    m, n = x.shape

    # Pad both dimensions up to multiples of 128
    padded_m = ((m + 127) // 128) * 128
    padded_n = ((n + 127) // 128) * 128

    x_padded = torch.zeros((padded_m, padded_n), dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x

    # Reshape into 128x128 blocks: [row_blocks, 128, col_blocks, 128]
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)

    # Find max |x| per block
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)

    # Scale so each block's max maps to 448.0, then cast to FP8
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    scale_factors = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))  # [row_blocks, col_blocks]

    # Crop the padding; .contiguous() gives the dense layout DeepGEMM expects
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), scale_factors

# Convert RHS with block scaling
rhs_fp8, rhs_scales = cast_to_fp8_per_block(rhs)
print(f"RHS FP8 shape: {rhs_fp8.shape}")
print(f"RHS scales shape: {rhs_scales.shape}")

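Why different scaling: The LHS (activations) gets fine-grained 1x128 scaling, with one scale per token per 128 channels, so an outlier in one token only affects its own group. The RHS (weights) uses coarser 128x128 blocks, which need far fewer scale factors.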
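The scale shapes printed above follow directly from the two granularities. Below is a small sketch of that arithmetic; the helper names are hypothetical. It also shows how many fewer scales the 128x128 scheme needs for the same matrix.

```python
import math

def per_token_scale_shape(rows: int, cols: int) -> tuple:
    """One scale per row per 128-column group (cast_to_fp8_per_token)."""
    return (rows, math.ceil(cols / 128))

def per_block_scale_shape(rows: int, cols: int) -> tuple:
    """One scale per 128x128 tile (cast_to_fp8_per_block)."""
    return (math.ceil(rows / 128), math.ceil(cols / 128))

for name, (r, c) in {"LHS": (128, 512), "RHS": (256, 512)}.items():
    tok, blk = per_token_scale_shape(r, c), per_block_scale_shape(r, c)
    print(f"{name} [{r}, {c}]: per-token scales {tok} = {tok[0] * tok[1]}, "
          f"per-block scales {blk} = {blk[0] * blk[1]}")
```

For the LHS, per-token scaling stores 512 scales where 128x128 blocks would store 4. The RHS needs only 8 block scales.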

Part 5: Preparing Tensors for DeepGEMM

# DeepGEMM requires specific tensor layouts
from deep_gemm import get_col_major_tma_aligned_tensor

# LHS scales must be stored column-major (M-major), with M padded for TMA alignment
lhs_scales_aligned = get_col_major_tma_aligned_tensor(lhs_scales)

# RHS scales must be contiguous
assert rhs_scales.is_contiguous()

# Package the inputs
lhs_input = (lhs_fp8, lhs_scales_aligned)
rhs_input = (rhs_fp8, rhs_scales)

print("✓ Tensors prepared for DeepGEMM")
print(f"LHS scales alignment: {lhs_scales_aligned.stride()}")

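TMA requirement: The Tensor Memory Accelerator (TMA) copies tiles between global and shared memory, and it requires global-memory strides to be multiples of 16 bytes. get_col_major_tma_aligned_tensor therefore stores the LHS scales column-major and pads the M dimension where needed.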
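Here is what the column-major, TMA-aligned layout means in numbers. Scales of consecutive tokens sit next to each other in memory, and each column of scales is padded so that it starts on a 16-byte boundary. This is a pure-Python sketch that assumes the 16-byte stride rule. DeepGEMM's helper may pad differently in detail, so compare against the `lhs_scales_aligned.stride()` output printed above. The helper names are hypothetical.

```python
import math

TMA_ALIGN_BYTES = 16  # assumption: TMA global strides must be multiples of 16 bytes

def tma_aligned_col_major(m: int, num_groups: int, elem_bytes: int = 4):
    """Strides (in elements) for an [m, num_groups] scale matrix stored column-major,
    with M padded so each column starts on a 16-byte boundary."""
    per_align = TMA_ALIGN_BYTES // elem_bytes   # 4 FP32 elements per 16 bytes
    aligned_m = math.ceil(m / per_align) * per_align
    strides = (1, aligned_m)                    # next row: +1, next column: +aligned_m
    return strides, aligned_m * num_groups      # strides and total storage in elements

def offset(i: int, j: int, strides) -> int:
    """Flat element offset of scale (i, j)."""
    return i * strides[0] + j * strides[1]

print(tma_aligned_col_major(128, 4))  # M is already a multiple of 4
print(tma_aligned_col_major(130, 4))  # M padded from 130 to 132
```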

Part 6: Your First DeepGEMM Call

# Perform the FP8 GEMM
deep_gemm.gemm_fp8_fp8_bf16_nt(lhs_input, rhs_input, output)

# Verify correctness
reference = lhs @ rhs.t()
error = torch.abs(output - reference).max().item()
relative_error = (error / torch.abs(reference).max().item()) * 100

print(f"Max absolute error: {error:.6f}")
print(f"Relative error: {relative_error:.3f}%")
print("✓ FP8 GEMM completed successfully!")

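Result: `output` now holds the BF16 result of the FP8 GEMM. DeepGEMM compiles its kernels just-in-time for each shape, so the first call also pays a one-time compilation cost. Expect a small nonzero error against the BF16 reference, since both inputs were rounded to FP8.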
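The scale layouts matter because a kernel can multiply the FP8 values directly and apply the two scales once per 128-wide slice of K. The pure-Python sketch below writes that math out on nested lists. It uses a configurable `group` so a tiny example fits. It illustrates the formula only, not DeepGEMM's kernel, and the function name is hypothetical.

```python
import random

def blockwise_scaled_gemm_nt(a_q, a_scales, b_q, b_scales, group=128):
    """out[i][j] = sum over K-groups g of
         (sum_{k in g} a_q[i][k] * b_q[j][k]) * a_scales[i][g] * b_scales[j // group][g]

    a_q: [M][K] quantized LHS, a_scales: [M][K/group] (per-token groups)
    b_q: [N][K] quantized RHS, b_scales: [N/group][K/group] (per-block)
    """
    m, n, k = len(a_q), len(b_q), len(a_q[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for g in range(0, k, group):
                partial = sum(a_q[i][kk] * b_q[j][kk] for kk in range(g, min(g + group, k)))
                # Scales are applied once per K-group, not once per element
                acc += partial * a_scales[i][g // group] * b_scales[j // group][g // group]
            out[i][j] = acc
    return out

random.seed(0)
G = 2
a_q = [[random.choice([-2.0, -1.0, 0.5, 1.0, 1.5]) for _ in range(4)] for _ in range(2)]
b_q = [[random.choice([-1.0, 0.25, 1.0, 2.0]) for _ in range(4)] for _ in range(4)]
a_s = [[0.5, 2.0], [1.0, 0.25]]  # [M][K/G]
b_s = [[1.5, 0.5], [2.0, 1.0]]   # [N/G][K/G]
out = blockwise_scaled_gemm_nt(a_q, a_s, b_q, b_s, group=G)

# Dequantize first, then run a plain NT GEMM: the numbers should match
ref = [[sum(a_q[i][k] * a_s[i][k // G] * b_q[j][k] * b_s[j // G][k // G] for k in range(4))
        for j in range(4)] for i in range(2)]
print(max(abs(out[i][j] - ref[i][j]) for i in range(2) for j in range(4)))
```

The scales factor out of each K-group's partial sum. They can therefore be applied after the low-precision products rather than to every element.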

Part 7: Understanding the Performance Gain

import time

def benchmark_gemm(func, *args, num_runs=10):
    # Warmup (also absorbs any one-time JIT compilation)
    for _ in range(3):
        func(*args)
    torch.cuda.synchronize()

    # Timing: synchronize so we measure finished GPU work, not just kernel launches
    start = time.perf_counter()
    for _ in range(num_runs):
        func(*args)
    torch.cuda.synchronize()

    return (time.perf_counter() - start) / num_runs

# Benchmark both versions (FP8 timing excludes the BF16 -> FP8 conversion)
fp8_time = benchmark_gemm(deep_gemm.gemm_fp8_fp8_bf16_nt, lhs_input, rhs_input, output)
# BF16 baseline writes into a preallocated buffer, like the FP8 call does
bf16_out = torch.empty_like(output)
bf16_time = benchmark_gemm(lambda x, y, out: torch.matmul(x, y.t(), out=out), lhs, rhs, bf16_out)

# Calculate throughput (TFLOPS)
ops = 2 * m * n * k  # FLOPs: each multiply-accumulate counts as 2 operations
fp8_tflops = ops / fp8_time / 1e12
bf16_tflops = ops / bf16_time / 1e12

print(f"FP8 GEMM:  {fp8_time*1000:.3f}ms ({fp8_tflops:.1f} TFLOPS)")
print(f"BF16 GEMM: {bf16_time*1000:.3f}ms ({bf16_tflops:.1f} TFLOPS)")
print(f"Speedup: {bf16_time/fp8_time:.1f}x")

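Performance: On Hopper, the dense FP8 tensor-core peak is about twice the BF16 peak. For large, compute-bound shapes, the best-case speedup is therefore roughly 2x, on top of halving the input memory. At a shape this small, launch overhead dominates both timings, so don't read much into the measured speedup.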
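A back-of-the-envelope check shows why the tiny benchmark says little about FP8 throughput. At peak rate, the 128x256x512 GEMM would finish in a few hundredths of a microsecond. A kernel launch typically costs a few microseconds. The sketch below assumes the H100 SXM spec-sheet dense FP8 peak; substitute your GPU's figure. The helper names are hypothetical.

```python
def gemm_flops(m: int, n: int, k: int) -> int:
    return 2 * m * n * k  # one multiply and one add per (i, j, k) triple

def ideal_time_us(m: int, n: int, k: int, peak_tflops: float) -> float:
    """Lower bound on runtime if the GPU ran at its peak rate the whole time."""
    return gemm_flops(m, n, k) / (peak_tflops * 1e12) * 1e6

PEAK_FP8_TFLOPS = 1979.0  # assumption: H100 SXM dense FP8 spec-sheet figure

for shape in [(128, 256, 512), (4096, 4096, 4096)]:
    print(shape, f"{ideal_time_us(*shape, PEAK_FP8_TFLOPS):.3f} us at peak")
```

Only at sizes like the second one does the kernel's own runtime dominate. That is where FP8's higher peak can show up.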

Part 8: Grouped GEMM - Processing Multiple Experts

# Simulate MoE (Mixture of Experts) scenario
num_experts = 4
tokens_per_expert = [128, 96, 112, 144]  # Variable tokens per expert
expert_dim = 512

# Create contiguous tensor for all tokens
total_tokens = sum(tokens_per_expert)
alignment = deep_gemm.get_m_alignment_for_contiguous_layout()  # 128

# Align each expert's token count
aligned_tokens = [((t + alignment - 1) // alignment) * alignment for t in tokens_per_expert]
total_aligned = sum(aligned_tokens)

print(f"Original tokens: {tokens_per_expert}")
print(f"Aligned tokens: {aligned_tokens}")
print(f"Total aligned: {total_aligned}")

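MoE insight: Different experts receive different numbers of tokens. A grouped GEMM handles all experts in one kernel launch instead of one GEMM per expert. In the contiguous layout, each expert's segment of rows is padded to a multiple of the M alignment (128 here).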
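Alignment has a cost: padded rows still occupy memory in the contiguous LHS and output. Below is a pure-Python sketch of the bookkeeping, assuming the 128-row alignment shown above. `contiguous_layout` is a hypothetical helper, not part of DeepGEMM.

```python
import math
from itertools import accumulate

def contiguous_layout(tokens_per_expert, alignment=128):
    """Pad each expert's segment to a multiple of `alignment` and report where it starts."""
    aligned = [math.ceil(t / alignment) * alignment for t in tokens_per_expert]
    starts = [0] + list(accumulate(aligned))[:-1]
    total = sum(aligned)
    padding_pct = 100.0 * (total - sum(tokens_per_expert)) / total
    return aligned, starts, total, padding_pct

aligned, starts, total, padding_pct = contiguous_layout([128, 96, 112, 144])
print(aligned, starts, total, f"{padding_pct:.1f}% padding rows")
```

With these counts, a quarter of the rows are padding. With more tokens per expert, the fraction shrinks.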

Part 9: Setting Up Grouped GEMM Data

# Create inputs for grouped GEMM
lhs_grouped = torch.randn((total_aligned, k), device='cuda', dtype=torch.bfloat16)
rhs_grouped = torch.randn((num_experts, n, k), device='cuda', dtype=torch.bfloat16)
output_grouped = torch.empty((total_aligned, n), device='cuda', dtype=torch.bfloat16)

# Create mapping tensor
m_indices = torch.empty(total_aligned, device='cuda', dtype=torch.int32)
start = 0
for expert_id, (orig_count, aligned_count) in enumerate(zip(tokens_per_expert, aligned_tokens)):
    # Real tokens get expert ID
    m_indices[start:start + orig_count] = expert_id
    # Padding tokens get -1 (ignored)
    m_indices[start + orig_count:start + aligned_count] = -1
    start += aligned_count

print(f"Mapping tensor shape: {m_indices.shape}")
print(f"Expert assignments: {m_indices[:20]}")  # First 20 tokens

Mapping: m_indices tells the kernel which expert, and therefore which slice of rhs_grouped, each row belongs to. Padding rows are marked -1.
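Going the other way is useful after the grouped GEMM, when you need to pull each expert's real output rows back out of the padded buffer. Below is a pure-Python sketch on a plain list. On the GPU you would do the same with a boolean mask such as `m_indices == expert_id`. The helper name is hypothetical, and a toy alignment of 4 keeps the example short.

```python
def rows_per_expert(m_indices):
    """Map expert id -> row indices holding real tokens (padding rows, -1, are skipped)."""
    rows = {}
    for row, expert in enumerate(m_indices):
        if expert != -1:
            rows.setdefault(expert, []).append(row)
    return rows

# Toy layout: alignment 4, expert 0 has 3 tokens and expert 1 has 2
example = [0, 0, 0, -1, 1, 1, -1, -1]
print(rows_per_expert(example))
```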

Part 10: Converting Grouped Data to FP8

# Convert LHS (same as before)
lhs_grouped_fp8, lhs_grouped_scales = cast_to_fp8_per_token(lhs_grouped)
lhs_grouped_scales = get_col_major_tma_aligned_tensor(lhs_grouped_scales)

# Convert each expert's RHS separately
rhs_grouped_fp8 = torch.empty_like(rhs_grouped, dtype=torch.float8_e4m3fn)
rhs_grouped_scales = torch.empty((num_experts, (n + 127) // 128, (k + 127) // 128), 
                                device='cuda', dtype=torch.float32)

for expert_id in range(num_experts):
    rhs_grouped_fp8[expert_id], rhs_grouped_scales[expert_id] = cast_to_fp8_per_block(rhs_grouped[expert_id])

# Package inputs
lhs_grouped_input = (lhs_grouped_fp8, lhs_grouped_scales)
rhs_grouped_input = (rhs_grouped_fp8, rhs_grouped_scales)

print("✓ Grouped data converted to FP8")

Expert-wise: Each expert has its own scaling factors for optimal precision.

Part 11: Running Grouped GEMM

# Perform grouped GEMM
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
    lhs_grouped_input, 
    rhs_grouped_input, 
    output_grouped, 
    m_indices
)

# Verify by computing reference
reference_grouped = torch.zeros_like(output_grouped)
start = 0
for expert_id, seg_len in enumerate(aligned_tokens):
    end = start + seg_len
    reference_grouped[start:end] = lhs_grouped[start:end] @ rhs_grouped[expert_id].t()
    start = end

# Mask out padding tokens for comparison
valid_mask = (m_indices != -1).unsqueeze(1)
output_masked = torch.where(valid_mask, output_grouped, torch.zeros_like(output_grouped))
reference_masked = torch.where(valid_mask, reference_grouped, torch.zeros_like(reference_grouped))

error = torch.abs(output_masked - reference_masked).max().item()
print(f"Grouped GEMM error: {error:.6f}")
print("✓ Grouped GEMM completed successfully!")

Validation: Compare against standard computation to ensure correctness.
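The max absolute error above is dominated by the largest outputs, so it grows with the magnitude of the data. A scale-invariant, cosine-style metric is an alternative; it is our assumption here, not something taken from DeepGEMM. Below is a pure-Python sketch with a hypothetical `similarity_diff` helper that masks out padding rows the same way as above.

```python
def similarity_diff(x, y, mask=None):
    """1 - 2*<x, y> / (|x|^2 + |y|^2): 0 for identical inputs, insensitive to overall scale.

    x, y: flat lists of numbers; mask: optional list of booleans selecting valid entries.
    """
    pairs = [(a, b) for i, (a, b) in enumerate(zip(x, y)) if mask is None or mask[i]]
    num = 2.0 * sum(a * b for a, b in pairs)
    den = sum(a * a + b * b for a, b in pairs)
    return 1.0 - num / den

out = [1.0, -2.0, 3.0, 99.0]
ref = [1.01, -1.98, 3.03, 0.0]  # last entry is a padding row
valid = [True, True, True, False]
print(similarity_diff(out, ref, valid))  # tiny: the valid entries agree closely
print(similarity_diff(out, ref))         # large: the unmasked padding row dominates
```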

Part 12: Weight Gradient GEMM

# For training: compute weight gradients
def setup_weight_gradient():
    m_grad, k_grad, n_grad = 256, 1024, 512

    # Activations (forward pass)
    activations = torch.randn((m_grad, k_grad), device='cuda', dtype=torch.bfloat16)

    # Gradient w.r.t. output (from backprop)
    grad_output = torch.randn((m_grad, n_grad), device='cuda', dtype=torch.bfloat16)

    # Existing FP32 weight gradient (e.g. from earlier micro-batches); the GEMM adds to it in place
    weight_grad = torch.randn((n_grad, k_grad), device='cuda', dtype=torch.float) * 0.1

    return activations, grad_output, weight_grad

activations, grad_output, weight_grad = setup_weight_gradient()

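# Convert to FP8. dW = grad_output^T @ activations reduces over the token dimension,
# so both operands are transposed to put tokens (the GEMM's K) in the last dimension
act_fp8, act_scales = cast_to_fp8_per_token(activations.t().contiguous())    # [k_grad, tokens]
grad_fp8, grad_scales = cast_to_fp8_per_token(grad_output.t().contiguous())  # [n_grad, tokens]

# Prepare inputs (both use per-token scales in the column-major, TMA-aligned layout)
act_input = (act_fp8, get_col_major_tma_aligned_tensor(act_scales))
grad_input = (grad_fp8, get_col_major_tma_aligned_tensor(grad_scales))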

print(f"Weight gradient shape: {weight_grad.shape}")
print(f"Accumulator dtype: {weight_grad.dtype}")  # FP32 for precision

Training context: Weight gradients accumulate many small updates that BF16 would round away, so the accumulator is kept in FP32.
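Why the accumulator is FP32: the pure-Python sketch below adds 0.001 to 1.0 ten thousand times. It rounds the running sum to either BF16 (8 significant bits) or FP32 after every step. The `to_bf16` and `to_f32` helpers are hypothetical, built on `struct`, and assume finite inputs.

```python
import struct

def to_f32(x: float) -> float:
    """Round a Python float to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round to BF16 (top 16 bits of FP32) with round-to-nearest-even; finite inputs only."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

acc_bf16 = acc_f32 = 1.0
for _ in range(10_000):
    acc_bf16 = to_bf16(acc_bf16 + 0.001)  # 0.001 is below half a BF16 step at 1.0
    acc_f32 = to_f32(acc_f32 + 0.001)

print(acc_bf16, acc_f32)  # BF16 never moves off 1.0; FP32 ends close to 11.0
```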

Part 13: Computing Weight Gradients

# Compute weight gradients with accumulation
original_grad = weight_grad.clone()

deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(grad_input, act_input, weight_grad)

# Verify: grad_output^T @ activations + original_grad
reference_update = grad_output.float().t() @ activations.float()
expected_grad = original_grad + reference_update

error = torch.abs(weight_grad - expected_grad).max().item()
relative_error = error / torch.abs(expected_grad).max().item()

print(f"Weight gradient error: {error:.6f}")
print(f"Relative error: {relative_error*100:.3f}%")
print("✓ Weight gradient computation successful!")

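Accumulation: The kernel adds the new gradient to the existing contents of weight_grad (weight_grad += grad_output^T @ activations). That is exactly what gradient accumulation across micro-batches needs.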
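In NT form, the kernel computes `out[i][j] += sum_k lhs[i][k] * rhs[j][k]`, with K as the last dimension of both inputs. For a weight gradient, the reduction runs over tokens. Both operands are therefore the transposes of `grad_output` and `activations`. Below is a tiny pure-Python check of that identity, with no quantization and a hypothetical `nt_accumulate` helper.

```python
def transpose(mat):
    return [list(col) for col in zip(*mat)]

def nt_accumulate(lhs, rhs, out):
    """In place: out[i][j] += sum_k lhs[i][k] * rhs[j][k] (both inputs are K-last)."""
    for i, lrow in enumerate(lhs):
        for j, rrow in enumerate(rhs):
            out[i][j] += sum(a * b for a, b in zip(lrow, rrow))
    return out

x = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]   # activations [tokens=3, in_features=2]
dy = [[1.0, 0.0], [2.0, 1.0], [-1.0, 4.0]]  # grad_output [tokens=3, out_features=2]
dw = [[0.1, 0.1], [0.1, 0.1]]               # existing gradient [out_features, in_features]

nt_accumulate(transpose(dy), transpose(x), dw)  # K = tokens for both operands
print(dw)  # dy^T @ x added on top of the existing 0.1 values
```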

Part 14: Performance Monitoring

def analyze_gemm_performance(m, n, k, operation="forward"):
    # Theoretical peak performance (assumed H100 SXM spec-sheet figures; adjust for your GPU)
    # ~1979 TFLOPS dense FP8 tensor-core peak
    ops = 2 * m * n * k
    peak_time = ops / 1979e12  # Theoretical minimum time

    # Memory traffic, assuming every input and output byte is moved exactly once
    fp8_bytes = (m * k + n * k) * 1  # FP8 inputs
    bf16_bytes = m * n * 2           # BF16 output
    lhs_scale_bytes = m * ((k + 127) // 128) * 4                   # FP32, one per 1x128 group
    rhs_scale_bytes = ((n + 127) // 128) * ((k + 127) // 128) * 4  # FP32, one per 128x128 block
    total_bytes = fp8_bytes + bf16_bytes + lhs_scale_bytes + rhs_scale_bytes

    # ~3.35 TB/s HBM3 bandwidth
    bandwidth_time = total_bytes / 3.35e12

    print(f"\n{operation.upper()} GEMM Analysis (M={m}, N={n}, K={k}):")
    print(f"Operations: {ops/1e9:.3f} GFLOP")
    print(f"Compute bound time: {peak_time*1e6:.2f} us")
    print(f"Memory bound time: {bandwidth_time*1e6:.2f} us")
    print(f"Bottleneck: {'Compute' if peak_time > bandwidth_time else 'Memory'}")

# Analyze our configurations
analyze_gemm_performance(128, 256, 512, "forward")
analyze_gemm_performance(256, 512, 1024, "weight_grad")

Performance tuning: Understanding compute vs memory bottlenecks helps optimize configurations.
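With a byte model like the one in `analyze_gemm_performance`, you can ask where square problems switch from memory-bound to compute-bound. Below is a pure-Python sketch. The peak and bandwidth figures are assumptions (H100 SXM spec-sheet numbers). Real kernels move more bytes than this model counts, so treat the crossover as optimistic.

```python
import math

PEAK_FLOPS = 1979e12  # assumption: H100 SXM dense FP8
BANDWIDTH = 3.35e12   # assumption: H100 SXM HBM3, bytes/s

def bytes_moved(m: int, n: int, k: int) -> int:
    """FP8 inputs, BF16 output, FP32 scales (1x128 for LHS, 128x128 for RHS), each moved once."""
    scales = (m * math.ceil(k / 128) + math.ceil(n / 128) * math.ceil(k / 128)) * 4
    return (m * k + n * k) + m * n * 2 + scales

def arithmetic_intensity(m: int, n: int, k: int) -> float:
    return 2 * m * n * k / bytes_moved(m, n, k)  # FLOPs per byte

ridge = PEAK_FLOPS / BANDWIDTH  # intensity needed to become compute-bound
first_compute_bound = next(s for s in (2 ** p for p in range(7, 16))
                           if arithmetic_intensity(s, s, s) > ridge)
print(f"ridge = {ridge:.0f} FLOP/byte; first compute-bound power-of-two size = {first_compute_bound}")
```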

Part 15: Advanced Configuration

# Control SM utilization for better efficiency
original_sms = deep_gemm.get_num_sms()
print(f"Default SMs: {original_sms}")

# Cap the SMs DeepGEMM may use, e.g. to leave some free for kernels running concurrently (such as communication)
deep_gemm.set_num_sms(original_sms // 2)
print(f"Reduced SMs: {deep_gemm.get_num_sms()}")

# Run a smaller GEMM
small_lhs = torch.randn((64, 256), device='cuda', dtype=torch.bfloat16)
small_rhs = torch.randn((128, 256), device='cuda', dtype=torch.bfloat16)
small_out = torch.empty((64, 128), device='cuda', dtype=torch.bfloat16)

small_lhs_fp8, small_lhs_scales = cast_to_fp8_per_token(small_lhs)
small_rhs_fp8, small_rhs_scales = cast_to_fp8_per_block(small_rhs)

deep_gemm.gemm_fp8_fp8_bf16_nt(
    (small_lhs_fp8, get_col_major_tma_aligned_tensor(small_lhs_scales)),
    (small_rhs_fp8, small_rhs_scales),
    small_out
)

# Restore original setting
deep_gemm.set_num_sms(original_sms)
print("✓ SM configuration demonstrated")

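Resource management: Capping DeepGEMM's SM count is mainly useful when other kernels, such as communication kernels, run at the same time and need SMs of their own. Restore the original value afterwards, as done above.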
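A small GEMM cannot use a whole GPU anyway, because the output is split into tiles and each tile is processed by one SM at a time. Below is a pure-Python sketch. It assumes 128x128 output tiles and 132 SMs (the H100 SXM count). DeepGEMM picks its own block sizes per shape, so the real tile count may differ. `tile_usage` is a hypothetical helper.

```python
import math

def tile_usage(m: int, n: int, num_sms: int, block_m: int = 128, block_n: int = 128):
    """Output tiles, waves of tiles across the SMs, and SMs busy in the last wave."""
    tiles = math.ceil(m / block_m) * math.ceil(n / block_n)
    waves = math.ceil(tiles / num_sms)
    last_wave = tiles - (waves - 1) * num_sms
    return tiles, waves, last_wave

print(tile_usage(64, 128, num_sms=132))     # the small GEMM above: a single tile
print(tile_usage(4096, 4096, num_sms=132))  # 1024 tiles over 8 waves
```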

Part 16: Debugging and Validation

def validate_fp8_conversion(original, fp8_data, scales, block=(1, 128)):
    """Dequantize FP8 data and check it against the original tensor.

    block=(1, 128) matches cast_to_fp8_per_token; block=(128, 128) matches cast_to_fp8_per_block.
    """
    m, n = fp8_data.shape
    # Expand each scale over the block it covers, then crop away any padding
    scales_full = scales.repeat_interleave(block[0], dim=0).repeat_interleave(block[1], dim=1)[:m, :n]
    reconstructed = fp8_data.float() * scales_full

    # Compare
    abs_error = torch.abs(original.float() - reconstructed).max().item()
    rel_error = abs_error / torch.abs(original.float()).max().item()

    print(f"FP8 conversion error: {abs_error:.6f} ({rel_error*100:.3f}% of max |x|)")
    # With 3 mantissa bits, rounding error can reach ~3.6% (16/448) of a block's max
    return rel_error < 0.05

# Validate our conversions
lhs_valid = validate_fp8_conversion(lhs, lhs_fp8, lhs_scales, block=(1, 128))
rhs_valid = validate_fp8_conversion(rhs, rhs_fp8, rhs_scales, block=(128, 128))

print(f"LHS conversion valid: {lhs_valid}")
print(f"RHS conversion valid: {rhs_valid}")

Quality assurance: Always validate FP8 conversions to ensure acceptable precision loss.
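A check like this also shows why fine-grained scaling is used. The pure-Python sketch below uses a simplified model of float8_e4m3fn, with hypothetical helpers and ties rounded toward the smaller value. A single large outlier inflates a shared per-tensor scale and pushes small values everywhere into the coarse subnormal range. Per-128 scales confine the damage to the outlier's own group.

```python
import math

# Simplified model of float8_e4m3fn: bias 7, 3 mantissa bits, max 448, NaN at S.1111.111
E4M3 = sorted({0.0} | {m * 2.0 ** -9 for m in range(1, 8)} |
              {(1 + m / 8) * 2.0 ** (e - 7) for e in range(1, 16) for m in range(8)
               if not (e == 15 and m == 7)})

def round_to_e4m3(x: float) -> float:
    """Nearest E4M3 value; ties go to the smaller magnitude (a simplification)."""
    mag = min(abs(x), 448.0)
    return math.copysign(min(E4M3, key=lambda v: abs(v - mag)), x)

def fake_quant(values, group):
    """Quantize and dequantize with one scale per `group` consecutive values (max |x| -> 448)."""
    out = []
    for start in range(0, len(values), group):
        chunk = values[start:start + group]
        amax = max(max(abs(v) for v in chunk), 1e-4)
        out += [round_to_e4m3(v * 448.0 / amax) * amax / 448.0 for v in chunk]
    return out

# 256 small activations plus one large outlier in the first 128-element group
values = [0.001 * ((i % 7) - 3) for i in range(256)]
values[5] = 200.0

for group in (len(values), 128):  # one per-tensor scale vs. per-128 scales
    deq = fake_quant(values, group)
    err = max(abs(a - b) for a, b in zip(values[128:], deq[128:]))
    print(f"group={group}: max error on the outlier-free half = {err:.2e}")
```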

Part 17: Memory Optimization

def estimate_memory_usage(shapes, operation="gemm"):
    """Estimate GPU memory usage for DeepGEMM operations"""
    m, n, k = shapes

    # Input tensors
    lhs_fp8 = m * k * 1          # FP8
    lhs_scales = m * ((k + 127) // 128) * 4  # FP32
    rhs_fp8 = n * k * 1          # FP8  
    rhs_scales = ((n + 127) // 128) * ((k + 127) // 128) * 4  # FP32

    # Output
    if operation == "gemm":
        output = m * n * 2       # BF16
    else:  # weight_grad
        output = m * n * 4       # FP32

    # Temporary workspace: a rough placeholder allowance, not a measured DeepGEMM figure
    workspace = max(m, n) * 1024 * 4

    total = lhs_fp8 + lhs_scales + rhs_fp8 + rhs_scales + output + workspace

    print(f"Memory usage for {shapes}:")
    print(f"  Inputs: {(lhs_fp8 + lhs_scales + rhs_fp8 + rhs_scales) / 1024**2:.1f} MB")
    print(f"  Output: {output / 1024**2:.1f} MB")
    print(f"  Workspace: {workspace / 1024**2:.1f} MB")
    print(f"  Total: {total / 1024**2:.1f} MB")

    return total

# Estimate for different problem sizes
estimate_memory_usage((1024, 2048, 4096), "gemm")
estimate_memory_usage((2048, 4096, 8192), "weight_grad")

Capacity planning: Understand memory requirements for different model sizes.

Part 18: Integration with Training Loops

class FP8LinearLayer:
    """Example of integrating DeepGEMM into a training loop"""

    def __init__(self, in_features, out_features):
        self.weight = torch.randn((out_features, in_features), 
                                device='cuda', dtype=torch.bfloat16)
        self.weight_grad = torch.zeros_like(self.weight, dtype=torch.float)

    def forward(self, x):
        # Convert inputs to FP8
        x_fp8, x_scales = cast_to_fp8_per_token(x)
        w_fp8, w_scales = cast_to_fp8_per_block(self.weight)

        # Prepare DeepGEMM inputs
        x_input = (x_fp8, get_col_major_tma_aligned_tensor(x_scales))
        w_input = (w_fp8, w_scales)

        # Allocate output
        output = torch.empty((x.shape[0], self.weight.shape[0]), 
                           device='cuda', dtype=torch.bfloat16)

        # Forward pass
        deep_gemm.gemm_fp8_fp8_bf16_nt(x_input, w_input, output)
        return output

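    def backward(self, x, grad_output):
        # dW = grad_output^T @ x reduces over tokens, so in NT form both operands
        # are transposed to make the token dimension (K) the last dimension
        x_t = x.t().contiguous()               # [in_features, tokens]
        grad_t = grad_output.t().contiguous()  # [out_features, tokens]

        # Convert to FP8 (per-token scaling along the new last dimension)
        x_fp8, x_scales = cast_to_fp8_per_token(x_t)
        grad_fp8, grad_scales = cast_to_fp8_per_token(grad_t)

        # Prepare inputs
        x_input = (x_fp8, get_col_major_tma_aligned_tensor(x_scales))
        grad_input = (grad_fp8, get_col_major_tma_aligned_tensor(grad_scales))

        # Accumulate weight gradients in place: weight_grad += grad_output^T @ x
        deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(grad_input, x_input, self.weight_grad)
        # (The input gradient, grad_output @ weight, is omitted in this example)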

# Demo usage
layer = FP8LinearLayer(512, 256)
x = torch.randn((128, 512), device='cuda', dtype=torch.bfloat16)

# Forward pass
y = layer.forward(x)
print(f"Forward output shape: {y.shape}")

# Backward pass
grad_y = torch.randn_like(y)
layer.backward(x, grad_y)
print(f"Weight grad shape: {layer.weight_grad.shape}")
print("✓ Training loop integration demonstrated")

Real-world usage: How to integrate DeepGEMM into actual neural network training.
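The layer above only computes the weight gradient. For reference, here is how all three GEMMs of a linear layer `y = x @ W^T` map onto the NT convention, where both operands are K-last. This is a pure-Python sketch of the shape bookkeeping with a hypothetical helper. Whether DeepGEMM expects exactly these layouts for each call, particularly for the input-gradient GEMM, is an assumption to check against its documentation.

```python
def linear_layer_nt_gemms(tokens: int, in_f: int, out_f: int) -> dict:
    """(lhs, rhs, out) shapes for out = lhs @ rhs^T, where lhs is [M, K] and rhs is [N, K]."""
    return {
        # y = x @ W^T: K = in_features
        "forward": ((tokens, in_f), (out_f, in_f), (tokens, out_f)),
        # dx = dy @ W = dy @ (W^T)^T: K = out_features, so W is supplied transposed
        "input_grad": ((tokens, out_f), (in_f, out_f), (tokens, in_f)),
        # dW = dy^T @ x: K = tokens, so both operands are supplied transposed
        "weight_grad": ((out_f, tokens), (in_f, tokens), (out_f, in_f)),
    }

for name, (lhs, rhs, out) in linear_layer_nt_gemms(128, 512, 256).items():
    assert lhs[1] == rhs[1], name         # shared K dimension
    assert out == (lhs[0], rhs[0]), name  # output is [M, N]
    print(f"{name:12s} lhs {lhs}  rhs {rhs}  out {out}")
```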

Key Takeaways

  1. FP8 = 2x Memory Savings: Half the storage of BF16 (plus small FP32 scale tensors), with proper scaling
  2. Scaling is Critical: Per-token and per-block strategies maintain precision
  3. TMA Alignment: Required for optimal hardware utilization
  4. Grouped Operations: Efficient for MoE and variable-size batches
  5. JIT Compilation: Automatic kernel optimization for each shape
  6. Memory Layout Matters: Column-major scales, contiguous tensors
  7. FP32 Accumulation: Use higher precision for gradients

Practice Challenge

# Create an MoE layer with 8 experts
# Process a batch with variable expert utilization
# Measure memory savings vs standard implementation

num_experts = 8
expert_tokens = [64, 128, 96, 112, 88, 144, 72, 104]  # Realistic distribution
hidden_dim = 2048

# Your implementation here:
# 1. Set up grouped GEMM inputs
# 2. Convert to FP8 
# 3. Run DeepGEMM
# 4. Compare with standard PyTorch
# 5. Measure performance and memory usage

print("Challenge: Implement efficient MoE with DeepGEMM!")
