Note as of 2 Aug 2025, GMT+8 1am: Updated article to include activation checkpointing.
Table of Contents
- Motivation
- Why study ZeRO
- Setup
- Why make a copy of the model weights for optimizer states
- Why do we have to make copies of the momentum and variance and not just recompute them on the fly?
- Stage 1
- Stage 2
- Stage 3
Motivation
This article seeks to implement by hand the Zero Redundancy Optimizer (ZeRO). The goal is to build an intuition for how to implement fully sharded data parallelism to enable training at scale.
This article has been written with the assistance of Google Gemini 2.5 Pro, and with reference to Stanford CS336 lecture 7 on parallelism.
Why study ZeRO
ZeRO optimizes memory and improves training speed while increasing the model size that can be efficiently trained (source - ZeRO: Memory Optimizations Toward Training Trillion Parameter Models). This is done by sharding the model's parameters, gradients, and optimizer states across multiple GPUs, with minimal memory redundancy in data-parallel training. ZeRO is implemented in three stages:
- Stage 1: Partitions the optimizer states across the data-parallel workers.
- Stage 2: In addition to partitioning the optimizer states, this stage also partitions the gradients.
- Stage 3: In addition to gradients and optimizer states, this stage also partitions the model parameters themselves.
The following diagram is lifted from the original ZeRO paper itself. Assuming a relatively small model (by modern standards) of 7.5 billion parameters, we can expect 120 GB of memory consumption on each GPU, which far exceeds the 80 GB of VRAM on an H100. We can bring the per-device memory consumption down to 31.4 GB, 16.6 GB, and finally 1.9 GB with ZeRO stages 1, 2, and 3 respectively.
(source - ZeRO: Memory Optimizations Toward Training Trillion Parameter Models)
Setup
We assume the following setup:
- Hardware: 2 GPUs (GPU0, GPU1)
- Model: a tiny model with just 4 weights.
  - `W_fp16 = [w1, w2, w3, w4]`: The low-precision (fp16) weights used for the fast forward and backward passes.
- Optimizer States:
  - `W_fp32 = [w1_fp32, w2_fp32, w3_fp32, w4_fp32]`: The high-precision (fp32) master copy of the weights.
  - Momentum `M` and Variance `V`: `M = [m1, m2, m3, m4]` and `V = [v1, v2, v3, v4]`. These are also in fp32.
Why make a copy of the model weights for optimizer states
The original fp16 weights are upcast to fp32 and kept as a separate master copy in the optimizer states to reduce rounding errors during training.
Let's say a gradient for a specific weight is very small. For example, gradient = 0.00001. The learning rate might also be small, say learning_rate = 0.001.
The update to the weight would be learning_rate * gradient = 0.00000001.
- In FP32 (32-bit): A 32-bit floating-point number has enough precision to represent this tiny update. You can add 0.00000001 to an existing weight and the value will actually change.
- In FP16 (16-bit): A 16-bit float has much less precision. A number as small as 0.00000001 might be rounded down to zero. If you add zero to your FP16 weight, the weight does not change. The update is completely lost.
If this happens for many steps in a row, your model stops learning for that parameter.
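Here is a quick NumPy check of this rounding effect. Note that the update is `1e-4` rather than `1e-8` so the contrast is visible at a weight of 2.0; the exact cutoff depends on the weight's magnitude, since floats have relative rather than absolute precision:

```python
import numpy as np

# The same small update survives in fp32 but is rounded away in fp16.
w_fp32 = np.float32(2.0)
w_fp16 = np.float16(2.0)
update = 1e-4

print(np.float32(w_fp32 + np.float32(update)))  # 2.0001 -> the update is kept
print(np.float16(w_fp16 + np.float16(update)))  # 2.0    -> the update is lost
```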
Why do we have to make copies of the momentum and variance and not just recompute them on the fly?
The key innovation of optimizers like SGD with Momentum, RMSprop, and Adam is that they do not just look at the gradient from the current mini-batch. They maintain a memory of past gradients to make a more intelligent update. This memory helps them navigate the loss landscape more effectively, avoiding oscillations and speeding up convergence.
This article assumes Adam is used.
A. Why We Must Store Momentum (The First Moment)
Momentum is formally defined as an exponential moving average (EMA) of the gradients.
The update rule is:
momentum_t = β₁ * momentum_{t-1} + (1 - β₁) * gradient_t
We can analyze the following:
- `momentum_t`: The momentum we are calculating for the current step `t`.
- `momentum_{t-1}`: This is the critical part. To calculate the new momentum, you absolutely need the value of the momentum from the previous step (`t-1`).
- `gradient_t`: The gradient from the current mini-batch.
- `β₁`: The decay rate (e.g., 0.9). It controls how much "memory" of past gradients is kept.
The "On-the-Fly" Problem: If you tried to calculate momentum "on the fly" without storing it, you would not have momentum_{t-1}
as the value is gone from the last step. It is also computationally infeasible to re-calculate the momentum of the entire history of training at every step.
Analogy: Imagine you're calculating the average temperature for the last 30 days. On day 31, to update your average, you don't re-measure the temperature of all 30 previous days. You simply take the average you already calculated for the first 30 days and combine it with the new measurement. momentum_{t-1}
is that stored average. It summarizes the entire history of gradients up to this point.
B. Why We Must Store Variance (The Second Moment)
The logic is identical for the variance term (the "v" in Adam). It is an EMA of the squared gradients.
The update rule is:
variance_t = β₂ * variance_{t-1} + (1 - β₂) * (gradient_t)²
Again, to calculate the variance for the current step `t`, you must have the stored value of the variance from the previous step, `variance_{t-1}`. Without storing this state between training steps, the algorithm simply breaks. It loses its "memory" of how much the gradients have been varying in the past, which is the key information it uses to adapt the learning rate for each parameter.
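A minimal sketch of these two recurrences in plain Python (using the usual `β₁ = 0.9` and `β₂ = 0.999`); the point is that each step needs only the previous step's stored `m` and `v`, never the full gradient history:

```python
# Carrying Adam's first and second moments across steps: each update reads
# only the previously stored m and v, plus the current gradient.
beta1, beta2 = 0.9, 0.999
m, v = 0.0, 0.0  # stored optimizer state, initialized before step 1

for t, grad in enumerate([-5.5, -1.2, 0.3], start=1):  # a toy gradient stream
    m = beta1 * m + (1 - beta1) * grad        # EMA of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2   # EMA of squared gradients
    print(t, m, v)
```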
Stage 1
The core idea is to eliminate memory redundancy of the optimizer states. This is significant for popular optimizers such as Adam, where for every model parameter (weight), we store its optimizer states. For an optimizer like Adam, this consists of a 4-byte fp32 master copy of the weight, a 4-byte momentum, and a 4-byte variance. This totals 12 bytes of memory for optimizer states per parameter, compared to just 2 bytes for the fp16 parameter itself.
Stage 1 consists of the following four steps, which I will go through one at a time:
- Step 1: Every rank computes a full gradient on their subset of the batch.
- Step 2: Reduce-scatter the gradients - incur a communication cost proportional to the number of parameters (specifically, (N-1)/N * #params for a ring algorithm, where N is the number of GPUs).
- Step 3: Each rank updates its parameter partition using its gradient slice and optimizer state.
- Step 4: All gather the parameters - incur #params communication cost.
Initialization and partitioning
In ZeRO Stage 1, we start with a standard Data Parallel setup where the batch of training data is split among the GPUs. The key innovation is that while the model weights are replicated, the optimizer-related states are partitioned to save memory.
- Model: A simple 2-layer neural network:
  - `h = w1*x1 + w2*x2`
  - `a = ReLU(h)` (`ReLU(x)` is `max(0, x)`)
  - `y_pred = w3*a + w4`
- Loss Function: Half Mean Squared Error: `Loss = 0.5 * (y_pred - y_true)^2`
- Replicated Data (present on BOTH GPU 0 and GPU 1):
  - The fp16 model weights: `W_fp16 = [w1=2.0, w2=-3.0, w3=1.0, w4=0.5]`
- Input Data (different for each GPU):
  - GPU 0: `Input = [x1=1, x2=3]`, `True Label (y_true) = 5`
  - GPU 1: `Input = [x1=2, x2=1]`, `True Label (y_true) = 7`
- Partitioned Optimizer States (these will be used in the parameter update step):
  - GPU 0 holds the fp32 master weights, momentum, and variance for `w1` and `w2`.
  - GPU 1 holds the fp32 master weights, momentum, and variance for `w3` and `w4`.
Memory Saving: Each GPU holds the full fp16 model (2 bytes/param) but only half of the fp32 model (4 bytes/param), half the momentum (4 bytes/param), and half the variance (4 bytes/param). We have eliminated the redundancy for all 12 bytes/parameter of optimizer states.
Step 1: Compute full gradient
Each GPU computes a full gradient on their subset of the batch.
On GPU 0: Forward and Backward Pass
GPU 0 performs its calculations completely independently.
Forward Pass (GPU 0)
- Calculate `h`: `h = w1*x1 + w2*x2 = (2.0 * 1) + (-3.0 * 3) = 2 - 9 = -7.0`
- Apply Activation `a`: `a = ReLU(h) = ReLU(-7.0) = 0`
- Calculate Prediction `y_pred`: `y_pred = w3*a + w4 = (1.0 * 0) + 0.5 = 0.5`
- Calculate Loss: `Loss = 0.5 * (0.5 - 5)^2 = 0.5 * (-4.5)^2 = 0.5 * 20.25 = 10.125`
Note on Activation Checkpointing: During this forward pass, to save memory, we do not store the intermediate values of `h` and `a`. They will be recomputed during the backward pass.
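In PyTorch, this trade-off is typically expressed with `torch.utils.checkpoint`. Below is a minimal sketch applied to the toy model; wrapping only the first layer and passing `use_reentrant=False` is one reasonable way to set it up, not the only one:

```python
import torch
from torch.utils.checkpoint import checkpoint

w = torch.tensor([2.0, -3.0, 1.0, 0.5], requires_grad=True)  # [w1, w2, w3, w4]
x = torch.tensor([1.0, 3.0])  # GPU 0's inputs

def layer1(x):
    # h and a are not kept for backward; they get recomputed on demand.
    return torch.relu(w[0] * x[0] + w[1] * x[1])

a = checkpoint(layer1, x, use_reentrant=False)
y_pred = w[2] * a + w[3]
loss = 0.5 * (y_pred - 5.0) ** 2
loss.backward()   # layer1 is re-run here to rebuild h and a
print(w.grad)     # tensor([ 0.0000,  0.0000,  0.0000, -4.5000])
```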
Backward Pass (GPU 0)
Now we go backward using the chain rule to find the gradient of the Loss with respect to each weight.
- Gradient wrt `y_pred`: `dL/dy_pred = (y_pred - y_true) = 0.5 - 5 = -4.5`
- Recompute activation `a` for gradient calculation: Before calculating gradients for `w3` and `w4`, we must first recompute the activation `a`, which was discarded after the forward pass. This requires re-running the first part of the model.
  - Recompute `h`: `h = w1*x1 + w2*x2 = (2.0 * 1) + (-3.0 * 3) = -7.0`
  - Recompute `a`: `a = ReLU(h) = ReLU(-7.0) = 0`
- Gradient wrt `w4`: `dL/dw4 = (dL/dy_pred) * (dy_pred/dw4)`. The derivative of `w3*a + w4` wrt `w4` is `1`, so `dL/dw4 = -4.5 * 1 = -4.5`
- Gradient wrt `w3`: `dL/dw3 = (dL/dy_pred) * (dy_pred/dw3)`. The derivative of `w3*a + w4` wrt `w3` is `a`, so `dL/dw3 = -4.5 * a = -4.5 * 0 = 0`
- Gradient wrt `a`: `dL/da = (dL/dy_pred) * (dy_pred/da)`. The derivative of `w3*a + w4` wrt `a` is `w3`, so `dL/da = -4.5 * w3 = -4.5 * 1.0 = -4.5`
- Gradient wrt `h`: `dL/dh = (dL/da) * (da/dh)`. The derivative of `ReLU(h)` is `0` if `h<0` and `1` if `h>0`. Since `h=-7.0`, the derivative is `0`, so `dL/dh = -4.5 * 0 = 0`
- Gradients wrt `w2` and `w1`: Since `dL/dh` is `0`, any gradient flowing back from it will also be zero.
  - `dL/dw2 = (dL/dh) * (dh/dw2) = 0 * x2 = 0`
  - `dL/dw1 = (dL/dh) * (dh/dw1) = 0 * x1 = 0`

Result from GPU 0: `Grad0 = [grad_w1=0, grad_w2=0, grad_w3=0, grad_w4=-4.5]`
On GPU 1: Forward and Backward Pass
GPU 1 performs the same calculations with its own data.
Forward Pass (GPU 1)
- Calculate `h`: `h = w1*x1 + w2*x2 = (2.0 * 2) + (-3.0 * 1) = 4 - 3 = 1.0`
- Apply Activation `a`: `a = ReLU(h) = ReLU(1.0) = 1.0`
- Calculate Prediction `y_pred`: `y_pred = w3*a + w4 = (1.0 * 1.0) + 0.5 = 1.5`
- Calculate Loss: `Loss = 0.5 * (1.5 - 7)^2 = 0.5 * (-5.5)^2 = 0.5 * 30.25 = 15.125`
Note on Activation Checkpointing: Again, the values for `h` and `a` are discarded to save memory.
Backward Pass (GPU 1)
- Gradient wrt `y_pred`: `dL/dy_pred = (y_pred - y_true) = 1.5 - 7 = -5.5`
- Recompute activation `a` for gradient calculation:
  - Recompute `h`: `h = w1*x1 + w2*x2 = (2.0 * 2) + (-3.0 * 1) = 1.0`
  - Recompute `a`: `a = ReLU(h) = ReLU(1.0) = 1.0`
- Gradient wrt `w4`: `dL/dw4 = -5.5 * 1 = -5.5`
- Gradient wrt `w3`: `dL/dw3 = -5.5 * a = -5.5 * 1.0 = -5.5`
- Gradient wrt `a`: `dL/da = -5.5 * w3 = -5.5 * 1.0 = -5.5`
- Gradient wrt `h`: `dL/dh = (dL/da) * (da/dh)`. Since `h=1.0`, the derivative of `ReLU(h)` is `1`, so `dL/dh = -5.5 * 1 = -5.5`
- Gradient wrt `w2`: `dL/dw2 = (dL/dh) * (dh/dw2)`. The derivative of `w1*x1 + w2*x2` wrt `w2` is `x2`, so `dL/dw2 = -5.5 * x2 = -5.5 * 1 = -5.5`
- Gradient wrt `w1`: `dL/dw1 = (dL/dh) * (dh/dw1)`. The derivative of `w1*x1 + w2*x2` wrt `w1` is `x1`, so `dL/dw1 = -5.5 * x1 = -5.5 * 2 = -11.0`

Result from GPU 1: `Grad1 = [grad_w1=-11.0, grad_w2=-5.5, grad_w3=-5.5, grad_w4=-5.5]`
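As a sanity check, the hand-derived gradients for both GPUs can be reproduced with PyTorch autograd (a small single-process script that simply reuses the weights and data from this example):

```python
import torch

# Cross-check the manual chain-rule results for each GPU's local batch.
def local_grads(x1, x2, y_true):
    w = torch.tensor([2.0, -3.0, 1.0, 0.5], requires_grad=True)  # [w1..w4]
    a = torch.relu(w[0] * x1 + w[1] * x2)
    y_pred = w[2] * a + w[3]
    loss = 0.5 * (y_pred - y_true) ** 2
    loss.backward()
    return loss.item(), w.grad

print(local_grads(1.0, 3.0, 5.0))  # GPU 0: 10.125, [0, 0, 0, -4.5]
print(local_grads(2.0, 1.0, 7.0))  # GPU 1: 15.125, [-11, -5.5, -5.5, -5.5]
```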
Step 2: Average and Scatter Gradients (Reduce-scatter)
The system performs a Reduce-scatter operation. This single operation accomplishes two things:
One, it averages the gradient tensors from all GPUs (i.e. the reduce):
- `g_avg1 = (0 + (-11.0)) / 2 = -5.5`
- `g_avg2 = (0 + (-5.5)) / 2 = -2.75`
- `g_avg3 = (0 + (-5.5)) / 2 = -2.75`
- `g_avg4 = (-4.5 + (-5.5)) / 2 = -5.0`
Conceptually, the fully averaged gradient tensor is `Grad_avg = [-5.5, -2.75, -2.75, -5.0]`.
Two, instead of creating a full tensor with these results on every GPU, the reduce-scatter operation immediately sends the relevant partition to its destination GPU. The full `Grad_avg` is a conceptual intermediate that is never fully materialized in every GPU's memory:
- GPU 0 is responsible for `w1` and `w2`, so it receives the first partition of the averaged gradient: `Grad_partition_0 = [-5.5, -2.75]`
- GPU 1 is responsible for `w3` and `w4`, so it receives the second partition of the averaged gradient: `Grad_partition_1 = [-2.75, -5.0]`
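A minimal sketch of this step with `torch.distributed`, assuming two processes have been launched (e.g. via `torchrun`) and the process group has already been initialized; the tensors are the gradient values from this example:

```python
import torch
import torch.distributed as dist

def reduce_scatter_gradients(local_grads: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    # Each rank ends up with #params / world_size averaged gradient entries.
    my_slice = torch.empty(local_grads.numel() // world_size)
    dist.reduce_scatter_tensor(my_slice, local_grads, op=dist.ReduceOp.SUM)
    return my_slice / world_size  # turn the sum into an average

# rank 0 passes [0, 0, 0, -4.5]         and receives [-5.5, -2.75]
# rank 1 passes [-11, -5.5, -5.5, -5.5] and receives [-2.75, -5.0]
```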
Step 3: Update partitioned parameters
Each GPU now updates its assigned partition of the parameters using its local partition of the gradient. The calculations happen in full fp32 precision.
Let us first recap the parameters on both GPUs:
- On BOTH GPU 0 and GPU 1:
  - The current fp16 model is: `W_fp16 = [w1=2.0, w2=-3.0, w3=1.0, w4=0.5]`
  - The averaged gradients from Step 2 are: `Grad_avg = [g1=-5.5, g2=-2.75, g3=-2.75, g4=-5.0]` (each GPU holds only its own partition)
- Partitioned Optimizer States:
  - GPU 0 holds the fp32 master weights, momentum (m), and variance (v) for `w1` and `w2`.
  - GPU 1 holds the fp32 master weights, momentum (m), and variance (v) for `w3` and `w4`.
For this example, let's define our optimizer hyperparameters. We'll use the Adam optimizer.
- Learning Rate (`lr`): `0.1`
- `beta1`: `0.9`
- `beta2`: `0.999`
- `epsilon`: `1e-8`
- We will assume this is the very first training step (`t=1`), so all initial momentums and variances are `0`.
Each GPU was able to compute the gradients for all four weights because the calculation only required two things it had locally:
- The complete set of `W_fp16` weights.
- Its own unique mini-batch of data.
I will now show how GPU 0 uses `[-5.5, -2.75]` to update its partition (`w1`, `w2`), and GPU 1 uses `[-2.75, -5.0]` to update its partition (`w3`, `w4`). This is where each GPU works only on its assigned partition. The calculations happen in full fp32 precision.
On GPU 0 (Responsible for w1 and w2)
GPU 0 uses `g1=-5.5` and `g2=-2.75` from the `Grad_avg` tensor.
Updating `w1`:
- Retrieve local states: `w1_fp32 = 2.0`, `m0_1 = 0`, `v0_1 = 0`.
- Update momentum: `m1_1 = beta1*m0_1 + (1-beta1)*g1 = 0.9*0 + 0.1*(-5.5) = -0.55`.
- Update variance: `v1_1 = beta2*v0_1 + (1-beta2)*g1^2 = 0.999*0 + 0.001*(-5.5)^2 = 0.03025`.
- Bias Correction (since t=1):
  - `m_hat = m1_1 / (1 - beta1^t) = -0.55 / (1 - 0.9) = -5.5`.
  - `v_hat = v1_1 / (1 - beta2^t) = 0.03025 / (1 - 0.999) = 30.25`.
- Calculate new `w1_fp32`: `w1'_fp32 = w1_fp32 - lr * m_hat / (sqrt(v_hat) + epsilon) = 2.0 - 0.1 * (-5.5) / (sqrt(30.25) + 1e-8) = 2.0 - 0.1 * (-1) = 2.1`.
- Cast to fp16 for the computational model: `w1'_fp16 = 2.1`.
Updating `w2`:
- Retrieve local states: `w2_fp32 = -3.0`, `m0_2 = 0`, `v0_2 = 0`.
- Update momentum: `m1_2 = 0.9*0 + 0.1*(-2.75) = -0.275`.
- Update variance: `v1_2 = 0.999*0 + 0.001*(-2.75)^2 = 0.0075625`.
- Bias Correction:
  - `m_hat = -0.275 / 0.1 = -2.75`.
  - `v_hat = 0.0075625 / 0.001 = 7.5625`.
- Calculate new `w2_fp32`: `w2'_fp32 = w2_fp32 - lr * m_hat / (sqrt(v_hat) + epsilon) = -3.0 - 0.1 * (-2.75) / (sqrt(7.5625) + 1e-8) = -3.0 - 0.1 * (-1) = -2.9`.
- Cast to fp16: `w2'_fp16 = -2.9`.
State on GPU 0 after updates: Its local fp16 model is now `[2.1, -2.9, 1.0, 0.5]` (partially updated).
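For readers who want to check the arithmetic, here is a small helper that mirrors the per-parameter Adam math above. It is written directly from the standard Adam formulas, so treat it as a sketch rather than a drop-in optimizer:

```python
import math

def adam_step(w, g, m, v, t, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    m = beta1 * m + (1 - beta1) * g          # momentum EMA
    v = beta2 * v + (1 - beta2) * g * g      # variance EMA
    m_hat = m / (1 - beta1 ** t)             # bias correction
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

print(adam_step(2.0, -5.5, 0.0, 0.0, t=1))    # ~ (2.1, -0.55, 0.03025)
print(adam_step(-3.0, -2.75, 0.0, 0.0, t=1))  # ~ (-2.9, -0.275, 0.0075625)
```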
On GPU 1 (Responsible for w3 and w4)
GPU 1 uses `g3=-2.75` and `g4=-5.0` from the `Grad_avg` tensor.
Updating `w3`:
- Retrieve local states: `w3_fp32 = 1.0`, `m0_3 = 0`, `v0_3 = 0`.
- The calculations are identical to `w2` since the gradient is the same (`-2.75`).
- Calculate new `w3_fp32`: `w3'_fp32 = 1.0 - 0.1 * (-1) = 1.1`.
- Cast to fp16: `w3'_fp16 = 1.1`.
Updating `w4`:
- Retrieve local states: `w4_fp32 = 0.5`, `m0_4 = 0`, `v0_4 = 0`.
- The calculations follow the same steps as for `w1`, now with gradient `g4 = -5.0`.
- `m1_4 = -0.5`, `v1_4 = 0.025`.
- `m_hat = -5.0`, `v_hat = 25`.
- Calculate new `w4_fp32`: `w4'_fp32 = 0.5 - 0.1 * (-5.0) / (sqrt(25) + 1e-8) = 0.5 - 0.1 * (-1) = 0.6`.
- Cast to fp16: `w4'_fp16 = 0.6`.
State on GPU 1 after updates: Its local fp16 model is now `[2.0, -3.0, 1.1, 0.6]` (partially updated).
At this point, the `W_fp16` models on the two GPUs are inconsistent with each other. This must be fixed before the next forward pass.
Step 4: Synchronize the Model (All-Gather)
The final step is to make the `W_fp16` computational model identical on all GPUs. This is done by having each GPU broadcast the new fp16 weights from its partition to everyone else.
- Broadcast:
  - GPU 0 sends its updated partition: `[w1'=2.1, w2'=-2.9]`.
  - GPU 1 sends its updated partition: `[w3'=1.1, w4'=0.6]`.
- Gather and Update:
  - GPU 0 receives `[1.1, 0.6]` from GPU 1 and uses it to overwrite the old values for `w3` and `w4` in its local `W_fp16` copy.
  - GPU 1 receives `[2.1, -2.9]` from GPU 0 and uses it to overwrite the old values for `w1` and `w2` in its local `W_fp16` copy.
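A minimal sketch of this step with `torch.distributed`, under the same assumptions as the earlier reduce-scatter sketch (process group already initialized, one process per GPU):

```python
import torch
import torch.distributed as dist

def all_gather_params(local_partition: torch.Tensor) -> torch.Tensor:
    slices = [torch.empty_like(local_partition) for _ in range(dist.get_world_size())]
    dist.all_gather(slices, local_partition)  # every rank receives every slice
    return torch.cat(slices)                  # reassemble the full W_fp16

# rank 0 passes [2.1, -2.9], rank 1 passes [1.1, 0.6];
# both ranks get back [2.1, -2.9, 1.1, 0.6].
```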
Final State (End of Training Step)
After the `all-gather`, the state on both GPUs is now consistent and ready for the next iteration:
- Final Synchronized `W_fp16` on both GPUs: `[2.1, -2.9, 1.1, 0.6]`
- Updated Optimizer States (still partitioned):
  - On GPU 0: `w1_fp32=2.1, m1_1=-0.55, v1_1=0.03025` and `w2_fp32=-2.9, m1_2=-0.275, v1_2=0.0075625`.
  - On GPU 1: `w3_fp32=1.1, m1_3=-0.275, v1_3=0.0075625` and `w4_fp32=0.6, m1_4=-0.5, v1_4=0.025`.
The loop is now complete. The system successfully updated the weights, saved a significant amount of memory by not replicating the optimizer states, and ended with a consistent model ready for the next forward pass. This completes ZeRO stage 1.
Why do we break up All-reduce into Reduce-scatter and All-gather?
We know from my previous article that `Reduce-scatter + All-gather = All-reduce`, in the sense that:
- The outputs on both sides of the equation are equivalent.
- The communication costs on both sides of the equation are equivalent (assuming both implement the ring-based algorithms).
With naive Distributed Data Parallel (DDP), where each GPU computes gradients on its local data batch, implementing All-reduce in one step means each GPU sends its calculated gradients to all other GPUs, and they all compute the sum (or average) and end up with the identical, full gradient tensor.
By splitting up All-reduce into two steps:
- Reduce-scatter: instead of every GPU getting the full averaged gradient tensor, this operation computes the average and then immediately scatters partitions of the averaged gradients to the GPUs responsible for them. Each GPU only receives the specific slice of the gradients it needs to update its portion of the optimizer states.
- All-gather: After each GPU updates its local partition of the model's parameters, they need to be synchronized across all GPUs for the next forward pass. This is done with an All-gather operation, where each GPU broadcasts its updated parameter partition to all other GPUs.
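The equivalence is easy to check numerically. The following toy, single-process simulation (no real communication, just the gradient values from this article) shows that reduce-scatter followed by all-gather reproduces the all-reduce result:

```python
import torch

grads = [torch.tensor([0.0, 0.0, 0.0, -4.5]),       # GPU 0's local gradients
         torch.tensor([-11.0, -5.5, -5.5, -5.5])]   # GPU 1's local gradients
world_size = len(grads)

all_reduce = sum(grads) / world_size                 # [-5.5, -2.75, -2.75, -5.0]

# Reduce-scatter: each "rank" keeps only its averaged slice.
chunks = [g.chunk(world_size) for g in grads]
partitions = [sum(c[r] for c in chunks) / world_size for r in range(world_size)]
# partitions[0] = [-5.5, -2.75] (GPU 0), partitions[1] = [-2.75, -5.0] (GPU 1)

# All-gather: concatenating the partitions recovers the all-reduce output.
assert torch.equal(torch.cat(partitions), all_reduce)
```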
Stage 2
In Stage 1, we partitioned the optimizer states, which provides the largest memory savings. However, a notable memory redundancy remains: during the backward pass, every GPU must store the full gradient tensor until the final reduction step. For very large models, this gradient tensor can itself become a memory bottleneck.
The core idea of Stage 2 is to eliminate this gradient redundancy by overlapping communication and computation. Instead of waiting for the entire backward pass to finish, Stage 2 begins reducing gradients as soon as they are calculated for a specific layer and then immediately discards them. This approach lowers the peak memory usage and can improve performance by hiding the network latency of communication behind the computation of the next layer.
Stage 2 consists of the following steps:
- Step 1: Incremental backward pass and overlapped gradient reduction. With activation checkpointing, during the forward pass, intermediate activations are not stored. During the backward pass, they are recomputed layer-by-layer just before the gradients for that layer are calculated.
- Step 2: Update partitioned parameters
- Step 3: All-gather the parameters
Step 1: Incremental Backward Pass & Overlapped Gradient Reduction
This is the key innovation of Stage 2. The backward pass and gradient communication are interleaved. As we backpropagate through the model, layer by layer, we perform the following steps:
- 1a. Recompute Activations (due to checkpointing): Before calculating the gradients for a specific layer, the necessary input activations for that layer (which were discarded during the forward pass) must be recomputed.
- 1b. Incremental Reduction: Immediately after a layer's gradients are computed, a `reduce` operation is triggered for them. In practice, this is often a `reduce-scatter` that sums the gradients and sends the correct partition to the GPU that owns that parameter slice (see the sketch after this list).
- 1c. Immediate Memory Deallocation: Once a layer's gradients have been sent to the appropriate worker, they are no longer needed on the local GPU and that memory is freed. This ensures the GPU never needs to hold the gradients for the entire model at once.
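One way to picture the overlap in code is with per-parameter gradient hooks. This is a rough sketch, not ZeRO's actual implementation: real systems bucket gradients, launch the collectives asynchronously, and manage the buffers explicitly, and `owner_rank` is a hypothetical mapping introduced here for illustration.

```python
import torch
import torch.distributed as dist

def attach_overlapped_reduction(model: torch.nn.Module, owner_rank: dict):
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    def make_hook(name, param):
        def hook(grad):
            # Fires as soon as this parameter's gradient is produced during
            # backward: sum it across ranks and deliver it to the owner.
            dist.reduce(grad, dst=owner_rank[name], op=dist.ReduceOp.SUM)
            if rank == owner_rank[name]:
                param.sharded_grad = grad / world_size  # the slice this rank keeps
            # A real implementation would now free the full-size gradient buffer
            # on non-owner ranks instead of accumulating it into .grad.
            return grad
        return hook

    for name, param in model.named_parameters():
        param.register_hook(make_hook(name, param))
```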
Let's trace this step-by-step with our running example.
A. Backpropagation Through the Final Layer (w3, w4)
The backward pass proceeds in reverse order, so it encounters the layer involving `w3` and `w4` first.
- Recompute Activations: To calculate gradients for this layer, we need the activation `a`. Since it was not stored during the forward pass, each GPU must recompute it.
  - GPU 0: Re-runs the first layer's computation: `h = (2.0*1) + (-3.0*3) = -7.0`, so the recomputed `a = ReLU(-7.0) = 0`.
  - GPU 1: Re-runs the first layer's computation: `h = (2.0*2) + (-3.0*1) = 1.0`, so the recomputed `a = ReLU(1.0) = 1.0`.
- Compute Local Gradients: From our work in Stage 1, we know the locally computed gradients for this layer on each GPU:
  - GPU 0: `Grad0_layer2 = [grad_w3=0, grad_w4=-4.5]`
  - GPU 1: `Grad1_layer2 = [grad_w3=-5.5, grad_w4=-5.5]`
- Immediate Reduce-Scatter: The system does not wait. It immediately performs a reduce-scatter operation only on these gradients. The goal is to compute the average and deliver the result to GPU 1, which is responsible for `w3` and `w4`.
  - Average `grad_w3`: `(0 + (-5.5)) / 2 = -2.75`
  - Average `grad_w4`: `(-4.5 + (-5.5)) / 2 = -5.0`
- Deliver Partition: The resulting averaged gradient partition `[-2.75, -5.0]` is sent to GPU 1.
- Free Memory:
  - GPU 0 no longer needs `Grad0_layer2`. This memory is freed.
  - GPU 1 has received its final partition, so it can free the temporary memory it used to hold `Grad1_layer2`.
B. Backpropagation Through the First Layer (w1, w2)
While the first communication may still be in flight, the backward pass continues to the next layer, which involves `w1` and `w2`.
- Recompute Activations: In this case, the input to the first layer is the raw input data (`x1`, `x2`), which is already in memory, so no recomputation is needed for the activations required by this specific layer's gradient calculation.
- Compute Local Gradients: We again use the results from our Stage 1 calculation:
  - GPU 0: `Grad0_layer1 = [grad_w1=0, grad_w2=0]`
  - GPU 1: `Grad1_layer1 = [grad_w1=-11.0, grad_w2=-5.5]`
- Immediate Reduce-Scatter: A second, independent reduce-scatter is performed on these gradients. The target is GPU 0, which owns `w1` and `w2`.
  - Average `grad_w1`: `(0 + (-11.0)) / 2 = -5.5`
  - Average `grad_w2`: `(0 + (-5.5)) / 2 = -2.75`
- Deliver Partition: The resulting averaged gradient partition `[-5.5, -2.75]` is sent to GPU 0.
- Free Memory: Both GPUs can now free the memory used to store the local gradients for `w1` and `w2`.
State After Incremental Reduction:
At the end of this overlapped process, the backward pass is complete. The GPUs did not store the full gradient tensor or the full set of activations from the forward pass. Instead, they each hold only the final, averaged partition that they need for the optimizer step.
- GPU 0 Memory: Holds the gradient partition `Grad_part0 = [-5.5, -2.75]`.
- GPU 1 Memory: Holds the gradient partition `Grad_part1 = [-2.75, -5.0]`.
Step 2: Update Partitioned Parameters
This step is now identical to Stage 1. Each GPU has the precise gradient information it needs to update its slice of the fp32 master parameters. We use the same Adam optimizer settings as before (`lr=0.1`, `beta1=0.9`, `beta2=0.999`, `t=1`).
On GPU 0 (Responsible for w1 and w2)
GPU 0 uses its local gradient partition `[-5.5, -2.75]`.
Updating `w1`:
- States: `w1_fp32 = 2.0`, `m0_1 = 0`, `v0_1 = 0`. Gradient `g1 = -5.5`.
- Update: The calculation is identical to Stage 1, resulting in:
  - `m1_1 = -0.55`
  - `v1_1 = 0.03025`
  - `w1'_fp32 = 2.1`
- Cast to fp16: `w1'_fp16 = 2.1`.
Updating `w2`:
- States: `w2_fp32 = -3.0`, `m0_2 = 0`, `v0_2 = 0`. Gradient `g2 = -2.75`.
- Update: The calculation is identical to Stage 1, resulting in:
  - `m1_2 = -0.275`
  - `v1_2 = 0.0075625`
  - `w2'_fp32 = -2.9`
- Cast to fp16: `w2'_fp16 = -2.9`.
On GPU 1 (Responsible for w3 and w4)
GPU 1 uses its local gradient partition `[-2.75, -5.0]`.
Updating `w3`:
- States: `w3_fp32 = 1.0`, `m0_3 = 0`, `v0_3 = 0`. Gradient `g3 = -2.75`.
- Update: The calculation is identical to Stage 1, resulting in:
  - `w3'_fp32 = 1.1`
- Cast to fp16: `w3'_fp16 = 1.1`.
Updating `w4`:
- States: `w4_fp32 = 0.5`, `m0_4 = 0`, `v0_4 = 0`. Gradient `g4 = -5.0`.
- Update: The calculation is identical to Stage 1, resulting in:
  - `w4'_fp32 = 0.6`
- Cast to fp16: `w4'_fp16 = 0.6`.
At this point, just like in Stage 1, the replicated `W_fp16` models on the two GPUs are inconsistent.
Step 3: All-Gather the Parameters
This final step is also identical to Stage 1. An `all-gather` operation is required to ensure the replicated `W_fp16` model is consistent on all GPUs before the next forward pass can begin.
- Broadcast Partitions:
  - GPU 0 sends its newly updated parameters: `[2.1, -2.9]`.
  - GPU 1 sends its newly updated parameters: `[1.1, 0.6]`.
- Gather and Reconstruct: Each GPU receives the partition from the other and assembles the full, synchronized model.
Final State (End of Training Step)
The training step is complete. The state on both GPUs is now consistent and ready for the next iteration:
- Final Synchronized `W_fp16` on both GPUs: `[2.1, -2.9, 1.1, 0.6]`
- Updated Optimizer States (still partitioned):
  - On GPU 0: `w1_fp32=2.1, w2_fp32=-2.9`, along with their new momentum and variance.
  - On GPU 1: `w3_fp32=1.1, w4_fp32=0.6`, along with their new momentum and variance.
Stage 2 achieves the exact same numerical outcome as Stage 1 but does so with a lower peak memory footprint by overlapping communication and computation during the backward pass.
Stage 3
In Stage 2, we eliminated redundancies in optimizer states and gradients. However, one last major redundancy exists: the model parameters (`W_fp16`) themselves are still fully replicated on every GPU.
The purpose of Stage 3 is to eliminate this final parameter redundancy. It achieves this by partitioning the parameters (`W_fp16`) across all GPUs from the very beginning. Each GPU is now responsible for storing and updating only its small slice of the complete model.
This creates a new challenge: during the forward and backward pass, a GPU will often need access to parameters that it does not hold in its own memory. ZeRO solves this by dynamically fetching the required parameter partitions from other GPUs just-in-time for the computation of each layer, and discarding them immediately after use. This is accomplished through a sequence of `all-gather`, compute, and `reduce-scatter` operations that are tightly integrated into the training loop.
Stage 3 consists of the following steps:
- Step 1: Initialization (different from stage 1)
- Step 2: Forward pass (layer-by-layer, discarding weights and activations)
- Step 3: Backward pass (recomputing activations and integrating gradient reduction)
- Step 4: Optimizer update (perfectly local)
- Step 5: Synchronization
Step-by-Step Example of Stage 3
Step 1: Initialization (The Key Change)
The memory layout is now fundamentally different. Everything is partitioned from the start.
- GPU 0 Memory:
  - Parameter Partition: `W_part0 = [w1_fp16=2.0, w2_fp16=-3.0]`
  - Optimizer States (fp32 master weights, momentum, variance) for `w1` and `w2` only.
- GPU 1 Memory:
  - Parameter Partition: `W_part1 = [w3_fp16=1.0, w4_fp16=0.5]`
  - Optimizer States (fp32 master weights, momentum, variance) for `w3` and `w4` only.
- Input Data (same as before):
  - GPU 0: `Input = [x1=1, x2=3]`, `True Label (y_true) = 5`
  - GPU 1: `Input = [x1=2, x2=1]`, `True Label (y_true) = 7`
Memory Saving: This provides a massive memory reduction. Neither GPU holds the full model.
Step 2: Forward Pass (Layer-by-Layer)
The forward pass proceeds one layer at a time, gathering the necessary weights for each computation.
1. Executing Layer 1 (`h = w1*x1 + w2*x2`)
- Problem: To compute `h`, both GPUs need `w1` and `w2`, but only GPU 0 holds them.
- Communication (All-Gather): An `all-gather` operation is triggered for the first layer's parameters.
  - GPU 0 sends its shard: `[2.0, -3.0]`.
  - GPU 1 has no parameters for this layer, so it contributes an empty tensor.
  - After the operation, both GPUs now have a temporary, full copy of the needed weights: `[w1=2.0, w2=-3.0]`.
- Computation (Forward Local): Each GPU computes its local value for `h`. These are the same calculations as in Stage 1.
  - GPU 0: `h = (2.0 * 1) + (-3.0 * 3) = -7.0`
  - GPU 1: `h = (2.0 * 2) + (-3.0 * 1) = 1.0`
- Memory Management (Free Full Weights and Activations): The temporary full weight tensor `[w1, w2]` is immediately discarded from memory on both GPUs. With activation checkpointing, they also discard the intermediate result `h`, only keeping the subsequent activation `a = ReLU(h)`.
  - GPU 0: `a = ReLU(-7.0) = 0`, then discards `h`.
  - GPU 1: `a = ReLU(1.0) = 1.0`, then discards `h`.
2. Executing Layer 2 (`y_pred = w3*a + w4`)
- Problem: Now the GPUs need `w3` and `w4`, which are only on GPU 1.
- Communication (All-Gather): A new `all-gather` is performed for the second layer's parameters.
  - GPU 1 sends its shard: `[1.0, 0.5]`.
  - GPU 0 contributes an empty tensor.
  - After the operation, both GPUs have a temporary copy of `[w3=1.0, w4=0.5]`.
- Computation (Forward Local): The final prediction and loss are calculated.
  - GPU 0:
    - `y_pred = (1.0 * 0) + 0.5 = 0.5`
    - `Loss = 0.5 * (0.5 - 5)^2 = 10.125`
  - GPU 1:
    - `y_pred = (1.0 * 1.0) + 0.5 = 1.5`
    - `Loss = 0.5 * (1.5 - 7)^2 = 15.125`
- Memory Management (Free Full Weights): The weights `[w3, w4]` and activation `a` are discarded. At the end of the forward pass, no GPU holds a complete set of weights or intermediate activations.
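Below is a rough sketch of this gather-compute-free pattern for the toy model. Because each layer's two weights happen to live entirely on one rank, a broadcast from the owner stands in for the parameter all-gather that real ZeRO-3 performs on padded shards; `OWNER`, `shards`, and `fetch_params` are illustrative names, not library APIs.

```python
import torch
import torch.distributed as dist

OWNER = {"layer1": 0, "layer2": 1}  # which rank stores each layer's shard

def fetch_params(name: str, shards: dict, numel: int) -> torch.Tensor:
    # The owning rank contributes its shard; everyone ends up with a full copy.
    full = shards[name].clone() if dist.get_rank() == OWNER[name] else torch.empty(numel)
    dist.broadcast(full, src=OWNER[name])
    return full

def sharded_forward(x1, x2, shards):
    w12 = fetch_params("layer1", shards, 2)        # gather [w1, w2]
    a = torch.relu(w12[0] * x1 + w12[1] * x2)
    del w12                                        # free layer 1's full weights

    w34 = fetch_params("layer2", shards, 2)        # gather [w3, w4]
    y_pred = w34[0] * a + w34[1]
    del w34                                        # free layer 2's full weights
    return y_pred

# rank 0: shards = {"layer1": torch.tensor([2.0, -3.0])}, inputs (1, 3)
# rank 1: shards = {"layer2": torch.tensor([1.0, 0.5])},  inputs (2, 1)
```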
Step 3: Backward Pass (Integrated Gradient Reduction and Activation Recomputation)
The backward pass mirrors the forward pass, but as gradients are computed for a layer, they are immediately reduced and scattered to their owning GPU. Before computing gradients for a layer, each GPU must first recompute the necessary activations.
1. Backprop through Layer 2
- Activation Recomputation: To calculate gradients with respect to `w3` and `w4`, we first need the activation `a`. Since it was discarded, it must be recomputed by running a forward pass through Layer 1.
  - Communication (All-Gather): An `all-gather` is performed for Layer 1's parameters, bringing `[w1=2.0, w2=-3.0]` to both GPUs.
  - Re-Computation: Each GPU recomputes `h` and then `a`.
    - GPU 0: `h = -7.0`, recomputed `a = 0`.
    - GPU 1: `h = 1.0`, recomputed `a = 1.0`.
  - Memory Management: The temporary `[w1, w2]` tensor used for recomputation is immediately discarded.
- Communication (All-Gather): To calculate gradients for `w3` and `w4`, we need the weights again. An `all-gather` brings `[w3=1.0, w4=0.5]` to both GPUs.
- Computation (Backward Local): Each GPU computes the local gradients for `w3` and `w4` using its recomputed activation `a`.
  - GPU 0: `dL/dy_pred = 0.5 - 5 = -4.5`.
    - `dL/dw4 = -4.5 * 1 = -4.5`.
    - `dL/dw3 = -4.5 * a = -4.5 * 0 = 0`.
    - Local Gradients (GPU 0): `Grad_part_L2_0 = [grad_w3=0, grad_w4=-4.5]`.
  - GPU 1: `dL/dy_pred = 1.5 - 7 = -5.5`.
    - `dL/dw4 = -5.5 * 1 = -5.5`.
    - `dL/dw3 = -5.5 * a = -5.5 * 1.0 = -5.5`.
    - Local Gradients (GPU 1): `Grad_part_L2_1 = [grad_w3=-5.5, grad_w4=-5.5]`.
- Communication (Reduce-Scatter): A `reduce-scatter` is performed only on these Layer 2 gradients.
  - Reduce (Average):
    - `g_avg3 = (0 + (-5.5)) / 2 = -2.75`
    - `g_avg4 = (-4.5 + (-5.5)) / 2 = -5.0`
  - Scatter: The result is sent to the owning GPU. Since GPU 1 owns `w3` and `w4`, it receives the averaged gradient partition.
    - GPU 1 stores: `Grad_final_1 = [-2.75, -5.0]`.
- Memory Management (Free Full Weights and Activations): The `[w3, w4]` tensor and the recomputed activation `a` are discarded from both GPUs.
2. Backprop through Layer 1
- Activation Recomputation: The input to Layer 1 is the original batch data (`x1`, `x2`), which is still available in memory. Therefore, no activation recomputation is needed for this step.
- Communication (All-Gather): The weights `[w1=2.0, w2=-3.0]` are gathered again for the gradient calculation.
- Computation (Backward Local): Each GPU computes local gradients for `w1` and `w2`.
  - GPU 0: The backpropagated gradient `dL/dh` was 0, so the local gradients are 0.
    - Local Gradients (GPU 0): `Grad_part_L1_0 = [grad_w1=0, grad_w2=0]`.
  - GPU 1: The backpropagated gradient `dL/dh` was -5.5.
    - `dL/dw1 = -5.5 * x1 = -5.5 * 2 = -11.0`.
    - `dL/dw2 = -5.5 * x2 = -5.5 * 1 = -5.5`.
    - Local Gradients (GPU 1): `Grad_part_L1_1 = [grad_w1=-11.0, grad_w2=-5.5]`.
- Communication (Reduce-Scatter): A `reduce-scatter` is performed on the Layer 1 gradients.
  - Reduce (Average):
    - `g_avg1 = (0 + (-11.0)) / 2 = -5.5`
    - `g_avg2 = (0 + (-5.5)) / 2 = -2.75`
  - Scatter: GPU 0 owns `w1` and `w2`, so it receives the final gradient partition.
    - GPU 0 stores: `Grad_final_0 = [-5.5, -2.75]`.
- Memory Management (Free Full Weights): The `[w1, w2]` tensor is discarded.
Step 4: Optimizer Update (Perfectly Local)
Each GPU now has the exact gradients it needs for its local parameter partition. The optimizer step requires no communication.
- On GPU 0 (updates w1, w2):
  - Uses its stored gradient `Grad_final_0 = [-5.5, -2.75]`.
  - The calculations are identical to Stage 1, resulting in:
    - New `w1'_fp32 = 2.1`.
    - New `w2'_fp32 = -2.9`.
  - New Parameter Partition on GPU 0: `W_part0 = [2.1, -2.9]`.
- On GPU 1 (updates w3, w4):
  - Uses its stored gradient `Grad_final_1 = [-2.75, -5.0]`.
  - The calculations are identical to Stage 1, resulting in:
    - New `w3'_fp32 = 1.1`.
    - New `w4'_fp32 = 0.6`.
  - New Parameter Partition on GPU 1: `W_part1 = [1.1, 0.6]`.
Step 5: Synchronization
There is no final `all-gather` step. The parameters are intended to live in their partitioned state. The "synchronization" was handled dynamically and incrementally throughout the forward and backward passes. The system is immediately ready for the next training iteration, starting again with the `all-gather` for the first layer's weights.
The final state of the system is a fully sharded model with updated weights:
- GPU 0 holds: `[w1=2.1, w2=-2.9]` and their corresponding optimizer states.
- GPU 1 holds: `[w3=1.1, w4=0.6]` and their corresponding optimizer states.