How LLMs Are Trained: From Petabytes to Parameters
Welcome back to our LLM series! If you think training a regular neural network is hard, consider this: training GPT-4 reportedly cost over $100 million and, by most estimates, consumed more electricity than 1,000 homes use in a year. Let's dive into how a task of that scale actually gets done.
The Training Pipeline: From Raw Text to Smart Model
┌────────────────────────────────────────────────────┐
│             The LLM Training Pipeline              │
├────────────────┬────────────────┬──────────────────┤
│    Stage 1     │    Stage 2     │     Stage 3      │
│      Data      │  Pre-training  │    Fine-tuning   │
│  Preparation   │                │                  │
│                │                │                  │
│  90% of work   │ $100M compute  │  Alignment magic │
│  10% of glory  │   2-6 months   │     1-2 weeks    │
└────────────────┴────────────────┴──────────────────┘
Stage 1: Data Preparation - The Unsung Hero
Before any training happens, we need massive amounts of high-quality text. Here's what the data pipeline looks like:
class DataPipeline:
    """Conceptual sketch of an LLM data pipeline; each step method is a placeholder."""
    def __init__(self):
        self.sources = {
            "common_crawl": "45TB raw web data",
            "github": "1TB code",
            "wikipedia": "20GB cleaned articles",
            "books": "500GB from Project Gutenberg",
            "academic_papers": "200GB from arXiv",
            "social_media": "Reddit, Twitter (filtered)"
        }

    def process_pipeline(self, raw_text):
        """From raw bytes to training tokens"""
        steps = [
            self.deduplicate,        # Remove duplicates
            self.filter_quality,     # Remove low-quality text
            self.remove_pii,         # Remove personal info
            self.language_filter,    # Keep mostly English
            self.tokenize,           # Convert to tokens
            self.create_sequences    # Create training examples
        ]
        for step in steps:
            raw_text = step(raw_text)
        return raw_text
# Approximate numbers in the spirit of Llama 3 (the 15-trillion-token figure is public;
# the earlier-stage sizes are rough illustrations):
llama3_data = {
    "raw_data_collected": "100+ TB",
    "after_deduplication": "30 TB",
    "after_filtering": "15 TB",
    "final_tokens": "15 trillion",
    "training_examples": "~15 billion packed sequences"
}
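As a concrete illustration of the deduplication step above, here's a minimal exact-match dedup sketch based on content hashing. It's my own toy helper, not the pipeline's real code; production pipelines layer fuzzy methods such as MinHash/LSH on top of this.

import hashlib

def deduplicate_exact(documents):
    """Drop documents whose normalized text hashes to something already seen."""
    seen = set()
    unique_docs = []
    for doc in documents:
        digest = hashlib.sha256(doc.strip().lower().encode("utf-8")).hexdigest()
        if digest not in seen:
            seen.add(digest)
            unique_docs.append(doc)
    return unique_docs

print(len(deduplicate_exact(["Hello world", "hello world ", "Goodbye"])))  # 2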
The Tokenization Process
Tokenization converts text into numbers the model can understand:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")  # gated repo: requires access approval

text = "The Transformer architecture changed everything!"
tokens = tokenizer.encode(text)
print(tokens)  # a short list of integer IDs; exact values depend on the tokenizer's vocabulary

# Decode individual IDs to see how the text was split into pieces:
for token_id in tokens[:4]:
    print(token_id, "->", repr(tokenizer.decode([token_id])))
# Different tokenizers split the same text differently:
example = "I'm learning about LLMs!"
tokenizer_gpt2 = AutoTokenizer.from_pretrained("gpt2")
tokenizer_t5 = AutoTokenizer.from_pretrained("t5-base")

print(f"GPT-2 tokens: {len(tokenizer_gpt2.tokenize(example))}")
print(f"Llama tokens: {len(tokenizer.tokenize(example))}")
print(f"T5 tokens: {len(tokenizer_t5.tokenize(example))}")
# The counts differ by a few tokens depending on each tokenizer's vocabulary
Stage 2: Pre-training - The Next Token Prediction Marathon
The core task is simple: predict the next token. But at this scale, simple becomes revolutionary:
def estimate_training_steps(batch_size=4, seq_length=2048):
    """
    Simplified view of how many optimizer steps pre-training takes.
    (Real runs use global batches of millions of tokens, so the actual step count is far smaller.)
    """
    # Each training step processes:
    tokens_per_batch = batch_size * seq_length  # 4 * 2048 = 8192 tokens
    # For a Llama-3-scale run (15 trillion tokens):
    total_steps = 15_000_000_000_000 / tokens_per_batch
    # With this toy batch size, that's ~1.8 billion training steps!
    return total_steps
# The loss function is plain cross-entropy over the vocabulary:
import torch
import torch.nn.functional as F

def pre_training_loss(logits, targets):
    """
    logits:  [batch_size, seq_len, vocab_size] - model predictions
    targets: [batch_size, seq_len] - the actual next tokens
             (i.e. the input IDs shifted one position to the left)
    """
    # Reshape for cross-entropy
    logits_flat = logits.view(-1, logits.size(-1))
    targets_flat = targets.view(-1)
    # Standard cross-entropy loss
    loss = F.cross_entropy(logits_flat, targets_flat)
    return loss
The Scaling Laws: Chinchilla's Insight
The Chinchilla paper (Hoffmann et al., 2022) changed how we think about scaling:
def compute_optimal_scaling(compute_budget=1e24):
    """
    Chinchilla's rule of thumb:
        Compute (FLOPs) ≈ 6 × N × D
    where N = parameters, D = training tokens.
    Compute-optimal: D ≈ 20 × N.
    """
    # Substituting D = 20N into C = 6·N·D gives C = 120·N², so:
    N_optimal = (compute_budget / (6 * 20)) ** 0.5
    D_optimal = 20 * N_optimal
    return {
        "parameters": int(N_optimal),    # ~90B for a 10^24 FLOP budget
        "tokens": int(D_optimal),        # ~1.8T tokens
        "compute_flops": compute_budget
    }
# This is why newer models are getting "smaller" but are trained on far more data
# (ratio = training tokens per parameter; Chinchilla-optimal is ~20):
scaling_comparison = {
    "pre_chinchilla": {
        "GPT-3":      {"params": "175B", "tokens": "300B", "ratio": "1.7x"},
        "Jurassic-1": {"params": "178B", "tokens": "300B", "ratio": "1.7x"},
    },
    "post_chinchilla": {
        "Chinchilla": {"params": "70B", "tokens": "1.4T", "ratio": "20x (compute-optimal)"},
        "Llama 2":    {"params": "70B", "tokens": "2T",   "ratio": "~28x"},
    }
}
Stage 3: Distributed Training - Taming the Memory Beast
The Memory Problem
A 70B-parameter model doesn't come close to fitting on a single GPU during training. Let's see why:
def calculate_memory_requirements(model_size_billion=70):
    """Estimate training memory for a 70B-parameter model (full FP32 states, no sharding)."""
    params = model_size_billion * 1_000_000_000
    memory_breakdown = {
        "parameters_fp32": params * 4,      # 280 GB
        "gradients_fp32": params * 4,       # 280 GB
        "optimizer_states": params * 8,     # 560 GB (Adam: m and v in FP32)
        "activations": params * 1.4,        # ~98 GB (very rough; depends on batch and sequence length)
        "temp_buffers": params * 0.2,       # ~14 GB
    }
    total = sum(memory_breakdown.values())
    return {
        "total_gb": total / 1_000_000_000,
        "breakdown": memory_breakdown,
        "h100_memory_gb": 80,               # an H100 has 80 GB
        "gpus_needed": total / (80 * 1_000_000_000)
    }

# Result: ~1232 GB in total, or about 16 H100 GPUs just to hold the training state!
Parallelism Strategies in Practice
Modern training uses multiple parallelism techniques simultaneously:
# Illustrative 3D-parallel configuration (in the spirit of Meta's published Llama training setups):
training_config = {
    "model_size": "70B",
    "gpus_used": 2048,
    "parallelism_strategy": {
        "tensor_parallelism": 8,      # Split individual matrices across 8 GPUs
        "pipeline_parallelism": 16,   # 16 pipeline stages
        "data_parallelism": 16,       # 16 data-parallel replicas (8 × 16 × 16 = 2048 GPUs)
    },
    "batch_size": {
        "micro_batch": 4,             # Per GPU per step
        "global_batch": 4 * 16 * 16,  # 1024 sequences (micro batch × data parallel × grad accumulation)
    }
}
# How to set this up with PyTorch FSDP (Fully Sharded Data Parallel):
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # Shard params, grads, and optimizer states
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    device_id=torch.cuda.current_device(),
)
ZeRO (Zero Redundancy Optimizer)
ZeRO eliminates memory redundancy by partitioning states across GPUs:
class ZeROOptimizer:
    """Conceptual ZeRO sketch (the real implementation lives in DeepSpeed); per-GPU savings are estimated in the sketch below."""
    def __init__(self, model, num_gpus):
        self.model = model
        self.num_gpus = num_gpus
        # Partition optimizer states evenly across GPUs
        self.partition_size = len(model.params) // num_gpus

    def partition_states(self):
        """Divide parameters, gradients, and optimizer states across GPUs (ZeRO stage 3)."""
        partitions = []
        for i in range(self.num_gpus):
            start = i * self.partition_size
            end = start + self.partition_size
            partition = {
                "params": self.model.params[start:end],
                "gradients": self.model.grads[start:end],
                "optimizer_states": self.model.opt_states[start:end]
            }
            partitions.append(partition)
        return partitions

    def all_reduce_gradients(self):
        """Synchronize gradients across GPUs."""
        # Each GPU holds only a shard of the gradients;
        # they must be reduced/gathered before the weight update.
        pass
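To see what each ZeRO stage buys you, here's a rough back-of-the-envelope sketch. The numbers reuse the FP32 figures from the 70B example above, ignore activations, and the 64-GPU count is just an assumption for illustration.

def zero_memory_per_gpu(params_gb=280, grads_gb=280, optim_gb=560, num_gpus=64):
    """Rough per-GPU memory (GB) for model states only, ignoring activations."""
    return {
        "no_zero": params_gb + grads_gb + optim_gb,              # everything replicated on every GPU
        "stage_1": params_gb + grads_gb + optim_gb / num_gpus,   # shard optimizer states
        "stage_2": params_gb + (grads_gb + optim_gb) / num_gpus, # + shard gradients
        "stage_3": (params_gb + grads_gb + optim_gb) / num_gpus, # + shard parameters
    }

print(zero_memory_per_gpu())
# Roughly: no_zero ~1120 GB, stage_1 ~569 GB, stage_2 ~293 GB, stage_3 ~17.5 GB per GPU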
Stage 4: Algorithmic Optimizations - The Secret Sauce
FlashAttention: I/O Optimization Masterpiece
Traditional attention has quadratic memory complexity. FlashAttention fixes this:
# Traditional attention (slow, memory-heavy)
import math
import torch

def standard_attention(Q, K, V):
    # Q, K, V: [batch, seq_len, d_model]
    d_k = Q.size(-1)
    # 1. Compute attention scores: O(N²) memory!
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores shape: [batch, seq_len, seq_len]
    # For a 32K sequence: 32K × 32K ≈ 1B entries per batch element!
    # 2. Softmax (needs the full score matrix in memory)
    attention_weights = torch.softmax(scores, dim=-1)
    # 3. Apply to values
    output = torch.matmul(attention_weights, V)
    return output
# FlashAttention (fast, memory-efficient) — conceptual pseudocode, not a working kernel
def flash_attention_impl(Q, K, V, block_size=256):
    """
    Key optimizations:
    1. Tiling: process blocks that fit in on-chip SRAM
    2. Recomputation: never materialize the full N×N score matrix
    3. Online softmax: accumulate softmax statistics block by block
    """
    batch_size, seq_len, d_model = Q.shape
    output = torch.zeros_like(Q)
    # Process in blocks
    for block_i in range(0, seq_len, block_size):
        for block_j in range(0, seq_len, block_size):
            # Load blocks into SRAM (fast on-chip memory)
            Q_block = Q[:, block_i:block_i+block_size, :]
            K_block = K[:, block_j:block_j+block_size, :]
            V_block = V[:, block_j:block_j+block_size, :]
            # Compute attention for this tile and rescale the running softmax
            # ... the fused kernel does this; block_output below is a placeholder ...
            output[:, block_i:block_i+block_size, :] += block_output
    return output
# Illustrative memory comparison for a 32K-token sequence (exact numbers depend on heads, dtype, batch):
memory_comparison = {
    "standard_attention_32k_seq": "~32 GB",   # O(n²) score matrices materialized
    "flash_attention_32k_seq": "~0.5 GB",     # O(n) — scores are never materialized
    "speedup": "2-4× faster in practice",
    "memory_savings": "10-20× less attention memory (and growing with sequence length)"
}
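You rarely need to write this kernel yourself: PyTorch 2.x exposes fused attention through torch.nn.functional.scaled_dot_product_attention, which dispatches to a FlashAttention-style kernel when one is available. A minimal sketch, assuming a CUDA GPU (shapes and dtypes here are just for illustration):

import torch
import torch.nn.functional as F

# Shapes: [batch, num_heads, seq_len, head_dim]
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)

# Fused, memory-efficient attention; no N×N score matrix is materialized
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 1024, 64])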
Mixed Precision Training
Modern GPUs are optimized for lower precision math:
import torch
from torch.cuda.amp import autocast, GradScaler

# Mixed-precision training loop
scaler = GradScaler()  # scales the loss to avoid FP16 gradient underflow

for batch in dataloader:
    optimizer.zero_grad()
    # Forward pass in mixed precision
    with autocast():
        logits = model(batch['input_ids'])
        loss = compute_loss(logits, batch['labels'])
    # Backward pass with loss scaling
    scaler.scale(loss).backward()
    # Optimizer step with unscaling
    scaler.step(optimizer)
    scaler.update()
# Why mixed precision works:
# 1. FP16/BF16 use 2 bytes vs 4 bytes for FP32 (~50% memory savings on weights and activations)
# 2. Tensor Cores: NVIDIA GPUs run FP16/BF16 matmuls much faster than FP32
# 3. Gradient scaling prevents underflow when gradients are kept in FP16

# Precision formats:
precision_formats = {
    "fp32": {"bits": 32, "range": "wide", "precision": "high"},
    "bf16": {"bits": 16, "range": "wide (same exponent as fp32)", "precision": "lower"},
    "fp16": {"bits": 16, "range": "narrow", "precision": "lower"},
    "tf32": {"bits": 19, "range": "wide", "precision": "medium (stored in 32 bits)"},
}
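A quick way to feel the range difference between fp16 and bf16 is to cast a couple of out-of-range values. This is just plain tensor casting, nothing model-specific:

import torch

small = torch.tensor(1e-8)
print(small.to(torch.float16))   # 0.0 — underflows (fp16's smallest subnormal is ~6e-8)
print(small.to(torch.bfloat16))  # ~1e-8 — bf16 keeps the fp32-like exponent range

big = torch.tensor(70000.0)
print(big.to(torch.float16))     # inf — overflows (fp16 max is ~65504)
print(big.to(torch.bfloat16))    # ~70144 — fits, but with far fewer mantissa bits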
Stage 5: Fine-tuning & Alignment
Supervised Fine-Tuning (SFT)
Pre-trained models complete text; SFT teaches them to follow instructions:
# SFT dataset format
sft_examples = [
    {
        "instruction": "Write Python code to sort a list",
        "input": "",
        "output": "def sort_list(lst):\n    return sorted(lst)"
    },
    {
        "instruction": "Explain quantum entanglement",
        "input": "",
        "output": "Quantum entanglement is a phenomenon where..."
    }
]
# SFT training step (conceptual, single example for clarity)
def sft_training_step(model, tokenizer, example):
    # Format: [INST] Instruction [/INST] Response
    prompt = f"[INST] {example['instruction']} [/INST] "

    # Tokenize prompt + response together
    inputs = tokenizer(
        prompt + example['output'],
        return_tensors='pt',
        padding=True,
        truncation=True
    )

    # Only compute loss on the response part: mask out the prompt tokens
    labels = inputs['input_ids'].clone()
    prompt_length = len(tokenizer(prompt)['input_ids'])
    labels[:, :prompt_length] = -100  # -100 is ignored by the loss

    # Forward pass with labels so the model returns the loss
    outputs = model(**inputs, labels=labels)
    return outputs.loss
Parameter-Efficient Fine-Tuning: LoRA & QLoRA
Full fine-tuning is expensive. LoRA makes it affordable:
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model  # used in the sketch below

# LoRA: Low-Rank Adaptation
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank=8):
        super().__init__()
        # Original weights are frozen
        self.original_layer = nn.Linear(in_dim, out_dim)
        for p in self.original_layer.parameters():
            p.requires_grad = False
        # LoRA adapters (trainable)
        self.lora_A = nn.Linear(in_dim, rank, bias=False)
        self.lora_B = nn.Linear(rank, out_dim, bias=False)
        # Start B at zero so training begins from the original weights
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x):
        original_out = self.original_layer(x)
        lora_out = self.lora_B(self.lora_A(x))
        return original_out + lora_out
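In practice you don't hand-roll this layer: the peft import above does it for you. A minimal sketch wrapping a small GPT-2 model (the target module name is specific to GPT-2; other architectures use different names, and the printed numbers are approximate):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")

lora_config = LoraConfig(
    r=8,                        # rank of the low-rank update
    lora_alpha=16,              # scaling factor
    target_modules=["c_attn"],  # GPT-2's fused attention projection
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()
# e.g. trainable params ≈ 0.3M of ≈124M total (~0.24%)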
# Parameter count comparison
def calculate_parameter_savings(model_size=7_000_000_000):
    """Compare full fine-tuning vs LoRA (back-of-the-envelope)."""
    full_finetune = {
        "trainable_params": model_size,
        "weight_memory_gb": model_size * 4 / 1e9,  # FP32 weights alone: ~28 GB
        "gpu_required": "A100 80GB or multiple GPUs (more once grads/optimizer states are added)"
    }
    lora_finetune = {
        "rank": 8,
        "trainable_params": int(model_size * (8 / 4096) * 2),  # rough estimate: ~0.4% of params
        "weight_memory_gb": model_size * 4 / 1e9 * 1.002,      # base weights + ~0.2% for adapters
        "gpu_required": "A single 24GB GPU for a 7B model (with quantized base weights)"
    }
    return {"full": full_finetune, "lora": lora_finetune}
# QLoRA: 4-bit quantization + LoRA
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",               # NormalFloat4
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True           # quantize the quantization constants too
)

# Load the base model in 4-bit
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    quantization_config=bnb_config,
    device_map="auto"
)
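To finish the QLoRA recipe, LoRA adapters are attached on top of the 4-bit base model loaded above. A minimal sketch with peft; the target module names are the usual Llama attention projections, and the rank/alpha values are just common defaults rather than anything from the source:

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)  # casts norms, enables input grads, etc.

qlora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Llama attention projections
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, qlora_config)
model.print_trainable_parameters()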
Stage 6: Evaluation - The Benchmarking Challenge
The Contamination Problem
def check_benchmark_contamination(training_data, benchmark_data):
    """Check whether benchmark questions leaked into the training set."""
    contamination_results = {}
    for benchmark_name, benchmark_qs in benchmark_data.items():
        # Check for exact matches
        exact_matches = set(training_data) & set(benchmark_qs)
        # Check for paraphrases/near-duplicates (much harder; placeholder helper)
        paraphrased_matches = check_paraphrases(training_data, benchmark_qs)
        contamination_rate = (len(exact_matches) + len(paraphrased_matches)) / len(benchmark_qs)
        contamination_results[benchmark_name] = {
            "exact_matches": len(exact_matches),
            "paraphrased_matches": len(paraphrased_matches),
            "contamination_rate": contamination_rate
        }
    return contamination_results
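Since check_paraphrases above is only a placeholder, a common practical proxy is n-gram overlap between benchmark questions and training documents. Here's a rough sketch; the 13-gram default follows a convention used in earlier contamination analyses, but treat it as an assumption rather than a standard:

def ngram_overlap(train_doc, benchmark_question, n=13):
    """Flag contamination if any n-gram of the benchmark question appears verbatim in the training doc."""
    def ngrams(text, n):
        words = text.lower().split()
        return {" ".join(words[i:i + n]) for i in range(len(words) - n + 1)}
    return len(ngrams(train_doc, n) & ngrams(benchmark_question, n)) > 0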
# Common benchmarks and their issues:
benchmarks = {
    "MMLU": {"tasks": 57, "subjects": "STEM, humanities, social sciences", "issue": "High contamination risk"},
    "GSM8K": {"tasks": "Grade-school math word problems", "issue": "Solutions available online"},
    "HumanEval": {"tasks": "Code generation", "issue": "Overlap with GitHub training data"},
    "BIG-bench": {"tasks": "200+", "issue": "Diverse but noisy"},
}
Practical Evaluation with LM Evaluation Harness
from lm_eval import evaluator

# Evaluate a model on multiple benchmarks with lm-evaluation-harness
results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Meta-Llama-3-8B",
    tasks=["mmlu", "gsm8k", "hellaswag"],
    num_fewshot=5,
    batch_size=8,
    device="cuda:0"
)

# Print results (metric names and result keys vary by task and harness version)
for task in ["mmlu", "gsm8k", "hellaswag"]:
    print(task, results["results"][task])
Hands-On: Training Your Own Model
Step-by-Step Guide with Minimal Code
# 1. Data preparation
from datasets import load_dataset
dataset = load_dataset("wikitext", "wikitext-103-raw-v1")
# Or use your own data:
# dataset = load_dataset("json", data_files="my_data.jsonl")
# 2. Tokenization
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=512)

tokenized_dataset = dataset.map(tokenize_function, batched=True)
# 3. Model initialization
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
model = AutoModelForCausalLM.from_pretrained("gpt2")
# 4. Training configuration
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,   # Effective batch size = 4 * 8 = 32
    fp16=True,                       # Mixed precision
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    logging_steps=10,
    learning_rate=5e-5,
    weight_decay=0.01,
    warmup_steps=500,
)
# 5. Training (the data collator pads batches and builds the labels for causal LM loss)
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
)
trainer.train()
Advanced: Distributed Training on Multiple GPUs
# Launch distributed training with Accelerate
accelerate config # Configure your setup
accelerate launch train.py # Launch training
# Or with DeepSpeed
deepspeed --num_gpus=8 train.py --deepspeed ds_config.json
# Example DeepSpeed config (ds_config.json).
# Note: train_batch_size must equal micro_batch × grad_accumulation × num_gpus (4 × 2 × 8 = 64 here).
{
  "train_batch_size": 64,
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 2,
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    }
  },
  "fp16": {
    "enabled": true
  }
}
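If you prefer the Hugging Face Trainer from the walkthrough above, the same config file can be handed to it through the deepspeed argument of TrainingArguments. A minimal sketch; the filename simply reuses the one above:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    fp16=True,
    deepspeed="ds_config.json",  # hands the ZeRO/offload settings to DeepSpeed
)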
Cost Analysis: What Does It Really Take?
def estimate_training_cost(model_size_billion=70, tokens_trillion=2):
    """Back-of-the-envelope cost of pre-training an LLM."""
    # Assumptions
    h100_hourly_rate = 4   # $/hour (typical cloud pricing)
    h100_flops = 0.99e15   # ~989 TFLOPS dense BF16 (the oft-quoted 1.98 PFLOPS assumes sparsity)
    mfu = 0.4              # realistic model FLOPs utilization (roughly 30-50%)

    # Compute required (FLOPs): C ≈ 6 · N · D
    compute_flops = 6 * model_size_billion * 1e9 * tokens_trillion * 1e12

    # GPU-hours needed at the assumed utilization
    gpu_hours = compute_flops / (h100_flops * mfu * 3600)

    cost = gpu_hours * h100_hourly_rate
    return {
        "model_size": f"{model_size_billion}B",
        "tokens": f"{tokens_trillion}T",
        "compute_flops": f"{compute_flops:.1e}",
        "gpu_hours": int(gpu_hours),
        "gpu_count_for_1_month": int(gpu_hours / (30 * 24)),
        "estimated_cost": f"${cost:,.0f}",
        "carbon_emissions_tons": round(gpu_hours * 0.0004, 1)  # very rough: ~0.4 tCO2 per 1000 GPU-hours
    }
# Example: training different model sizes at roughly Chinchilla-optimal token counts (20 tokens/param)
for model_size in [7, 13, 70, 700]:
    result = estimate_training_cost(model_size, model_size * 0.02)
    print(f"{result['model_size']}: {result['estimated_cost']}")
Key Takeaways & Best Practices
1. Start Small, Scale Smart
training_progression = [
    "1. Toy model (1M params) on a CPU",
    "2. Small model (100M params) on a single GPU",
    "3. Medium model (1B params) with mixed precision",
    "4. Large model (7B params) with FSDP",
    "5. Very large model (70B+ params) with full 3D parallelism"
]
2. Monitor Everything
# Essential metrics to track
metrics_to_monitor = {
    "loss": {"trend": "should decrease steadily", "warning": "plateaus or spikes"},
    "perplexity": {"trend": "should decrease", "note": "exp(loss); 'good' values depend on data and tokenizer"},
    "gradient_norm": {"warning": "sudden spikes or steady growth (exploding) / collapse toward zero (vanishing)"},
    "learning_rate": {"schedule": "warmup, then decay (e.g. cosine)"},
    "memory_usage": {"warning": "> 90% of GPU memory"},
    "throughput": {"tokens_per_sec": "the key efficiency number"},
}
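A minimal sketch of pushing these metrics to Weights & Biases; it assumes you're logged in to wandb, and the dummy loop simply stands in for a real training loop:

import math
import wandb

run = wandb.init(project="llm-pretraining", config={"model_size": "125M", "lr": 3e-4})

# Inside the training loop you would log real values; placeholder numbers shown here:
for step in range(3):
    loss = 2.5 - 0.1 * step
    wandb.log({
        "loss": loss,
        "perplexity": math.exp(loss),
        "learning_rate": 3e-4,
        "grad_norm": 0.8,
        "tokens_per_sec": 45_000,
    })

run.finish()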
3. Debugging Common Issues
def debug_training_issues():
    issues = {
        "loss_not_decreasing": [
            "Check the learning rate (too high or too low)",
            "Verify the data pipeline",
            "Check for gradient issues",
            "Try a smaller batch size"
        ],
        "out_of_memory": [
            "Enable gradient checkpointing",   # one-liner: see the sketch after this function
            "Use mixed precision",
            "Reduce the batch size",
            "Use memory-efficient attention",
            "Enable optimizer state sharding (ZeRO/FSDP)"
        ],
        "training_slow": [
            "Enable FlashAttention",
            "Increase the batch size if memory allows",
            "Use TF32/BF16 instead of FP32",
            "Profile with the PyTorch Profiler"
        ],
        "model_not_converging": [
            "Check data quality",
            "Try a different initialization",
            "Adjust the learning rate schedule",
            "Add more regularization"
        ]
    }
    return issues
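Two of the out-of-memory fixes above are one-liners for Hugging Face models. A minimal sketch, shown on GPT-2 purely for illustration:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.bfloat16,        # low-precision weights
)
model.gradient_checkpointing_enable()  # trade extra compute for much less activation memory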
What's Next in LLM Training?
The field is evolving rapidly:
- Mixture of Experts (MoE) - Sparse activation for trillion-parameter models
- Multimodal training - Joint text/image/video understanding
- Continuous learning - Models that learn without catastrophic forgetting
- Efficiency breakthroughs - New architectures that need less compute
Resources & Tools
Essential Libraries
training_stack = {
    "modeling": ["transformers", "triton", "flash-attention"],
    "training": ["pytorch", "deepspeed", "accelerate"],
    "data": ["datasets", "webdataset", "mosaicml-streaming"],
    "monitoring": ["wandb", "tensorboard", "mlflow"],
    "deployment": ["vllm", "tensorrt-llm", "llama.cpp (ggml)"],
}
Learning Resources
- Hugging Face Course - Excellent practical guide
- Stanford CS329S: Machine Learning Systems Design - Systems perspective
- PyTorch Distributed Tutorials - Learn distributed training
- LLM-Performance-Engineering - Optimization techniques
Final Thoughts
Training LLMs is no longer just about having the biggest GPU cluster. It's about:
- Smart scaling (Chinchilla laws)
- Efficient systems (distributed training, FlashAttention)
- High-quality data (curation beats quantity)
- Careful monitoring (debugging at scale)
The biggest misconception? That bigger is always better. The truth: better data and smarter training beat brute force.
**Discussion Time!**
- What's the largest model you've trained?
- What was your biggest training challenge?
- Any cool optimization tricks you've discovered?
**Try It Yourself:**
# Quick start with a small model
git clone https://github.com/karpathy/nanoGPT
cd nanoGPT
python data/shakespeare/prepare.py
python train.py config/train_shakespeare_char.py
Next up: We'll dive into **Alignment and RLHF** - how we make these powerful models actually helpful, honest, and harmless.