Table of Contents
- Motivation
- What is tensor parallelism
- When do we implement tensor parallelism
- Setup
- Forward pass
- Why split Matrix A by columns, Matrix B by rows
- Backward pass (backpropagation)
- Communication cost of tensor parallelism
- Code
Motivation
Tensor parallelism is widely used in training and deploying massive neural networks such as LLMs across multiple GPUs. While libraries such as DeepSpeed provide out-of-the-box solutions for tensor parallelism, understanding how to implement tensor parallelism by hand may help with:
- Building intuition: instead of just knowing tensor parallelism splits tensors, you will understand how the choice of splitting by "columns" in one layer and "rows" in the next is essential for the math to work out.
- Understanding performance bottlenecks: You will gain an understanding of why tensor parallelism is communication intensive and why it is generally recommended to run it over high-speed interconnects such as NVLink.
- Debugging: An intuitive mental model of tensor parallelism may be useful for debugging when distributed training jobs fail and the error messages are cryptic.
This article can also serve as a complementary study guide for Stanford's CS336 Lecture 7 on Parallelism. This article was written with the assistance of Gemini 2.5 Pro.
What is tensor parallelism
When a model becomes too massive to fit into a single GPU's memory, tensor parallelism offers a way to split up the workload. Tensor parallelism accomplishes two goals simultaneously:
- Distributes Memory Load: It enables the combining of the memory of several GPUs to store a single, massive layer that would be too large for any one device.
- Parallelizes Computation: Each GPU performs a smaller portion of the matrix multiplication using its local shard of the weight matrix. This allows the GPUs to work on the same forward or backward pass of a single layer at the same time, speeding up the computation.
When do we implement tensor parallelism
Tensor parallelism comes at a significant cost of communication overhead. This article will demonstrate why tensor parallelism requires a high-volume, synchronous communication step (All-Reduce) twice for every single block in the model: once during the forward pass and once during the backward pass. This creates a fundamental trade-off:
- Pipeline Parallelism: In pipeline parallelism, you split groups of layers (stages) across GPUs. Communication is less frequent, happening only at the boundaries between stages. However, pipeline parallelism introduces a "bubble" of idle time at the beginning and end of each batch as the pipeline fills and drains, reducing hardware utilization.
- Tensor Parallelism: In tensor parallelism, you split the work within each layer. This helps to keep all GPUs perfectly synchronized and busy, completely eliminating the pipeline "bubble." But the price for this high utilization is constant, layer-by-layer communication across all participating GPUs.
The more intense communication pattern of tensor parallelism, compared with pipeline parallelism, leads to a rule of thumb for designing large-scale training setups:
- Use Tensor Parallelism for Intra-Node Scaling: When you have a group of GPUs (e.g., the 8 A100s in a DGX node) connected with a high-speed backplane, tensor parallelism is the ideal way to scale up to handle massive layers.
- Use Pipeline and Data Parallelism for Inter-Node Scaling: To scale beyond a single node, you typically combine tensor parallelism with pipeline parallelism, where each stage of the pipeline is a tensor-parallel group of GPUs. This minimizes the slow communication across nodes to only what is necessary between pipeline stages.
Setup
We will walk through the forward pass of a two-layer Multi-Layer Perceptron (MLP) block found in a Transformer, which consists of two main operations:
- Y = GeLU(X * A)
- Z = Dropout(Y * B)
We'll assume we have two GPUs, GPU 1 and GPU 2.
Source: Stanford CS 336 Language Modeling from Scratch, Lecture 7: Parallelism 1, 58 min 12 sec
Forward pass
We begin with the forward pass; backpropagation is covered in a later section.
The forward pass in a neural network is the process of taking an input and propagating it through the layers of the model to generate a final output. In the context of tensor parallelism, this process is modified because the model's large weight matrices are too big for a single GPU and are instead sharded, or split, across multiple GPUs.
To execute a matrix multiplication, each GPU performs a smaller calculation in parallel using its local shard of the weights. This often follows a specific pattern where one layer's weights are split by columns (requiring no communication after the multiplication) and the next layer's weights are split by rows. This row-parallel approach produces partial results on each GPU that must then be aggregated through a collective communication operation, such as an All-Reduce, to ensure the final activation tensor is mathematically equivalent to the result that would have been computed on a single, monolithic GPU.
Step 1: Initial Setup - Defining and Splitting Matrices
First, let's define our input matrix X and the weight matrices A and B for our two linear layers. We will use small, simple numbers for clarity.
Let our input X be a 2x4 matrix:

X = [[1, 2, 3, 4],
     [5, 6, 7, 8]]

Let the first weight matrix A be a 4x6 matrix:

A = [[1, 2, 3, 4, 5, 6],
     [7, 8, 9, 1, 2, 3],
     [4, 5, 6, 7, 8, 9],
     [1, 2, 3, 4, 5, 6]]

Let the second weight matrix B be a 6x4 matrix:

B = [[6, 5, 4, 3],
     [2, 1, 7, 8],
     [9, 8, 6, 5],
     [4, 3, 2, 1],
     [1, 2, 3, 4],
     [5, 6, 7, 8]]
Now, we split these matrices across our two GPUs.
- Matrix A is split vertically (by columns). GPU 1 gets the first 3 columns (A1), and GPU 2 gets the last 3 columns (A2). This is known as column parallelism.

A1 = [[1, 2, 3],
      [7, 8, 9],
      [4, 5, 6],
      [1, 2, 3]]   (on GPU 1)

A2 = [[4, 5, 6],
      [1, 2, 3],
      [7, 8, 9],
      [4, 5, 6]]   (on GPU 2)

- Matrix B is split horizontally (by rows). GPU 1 gets the first 3 rows (B1), and GPU 2 gets the last 3 rows (B2). This is known as row parallelism.

B1 = [[6, 5, 4, 3],
      [2, 1, 7, 8],
      [9, 8, 6, 5]]   (on GPU 1)

B2 = [[4, 3, 2, 1],
      [1, 2, 3, 4],
      [5, 6, 7, 8]]   (on GPU 2)
We split Matrix A by columns and Matrix B by rows so as to minimise the communication cost incurred by the all-reduce operation. The section "Why split Matrix A by columns, Matrix B by rows" below goes into more detail on this choice.
Step 2: Forward Pass for Y = GeLU(X * A)
This operation is parallelized using the column-sharded matrix A. The input X is broadcast to both GPUs.
On GPU 1: Calculate Y_intermediate_1 = X * A1
We perform matrix multiplication. The output will be a 2x3 matrix.
Calculating each element:
- Y_intermediate_1[1,1] = (1*1) + (2*7) + (3*4) + (4*1) = 1 + 14 + 12 + 4 = 31
- Y_intermediate_1[1,2] = (1*2) + (2*8) + (3*5) + (4*2) = 2 + 16 + 15 + 8 = 41
- Y_intermediate_1[1,3] = (1*3) + (2*9) + (3*6) + (4*3) = 3 + 18 + 18 + 12 = 51
- Y_intermediate_1[2,1] = (5*1) + (6*7) + (7*4) + (8*1) = 5 + 42 + 28 + 8 = 83
- Y_intermediate_1[2,2] = (5*2) + (6*8) + (7*5) + (8*2) = 10 + 48 + 35 + 16 = 109
- Y_intermediate_1[2,3] = (5*3) + (6*9) + (7*6) + (8*3) = 15 + 54 + 42 + 24 = 135
So, the result of the matrix multiplication on GPU 1 is:

Y_intermediate_1 = [[31, 41, 51],
                    [83, 109, 135]]
On GPU 2: Calculate Y_intermediate_2 = X * A2
Similarly, we perform the multiplication on GPU 2.
Calculating each element:
- Y_intermediate_2[1,1] = (1*4) + (2*1) + (3*7) + (4*4) = 4 + 2 + 21 + 16 = 43
- Y_intermediate_2[1,2] = (1*5) + (2*2) + (3*8) + (4*5) = 5 + 4 + 24 + 20 = 53
- Y_intermediate_2[1,3] = (1*6) + (2*3) + (3*9) + (4*6) = 6 + 6 + 27 + 24 = 63
- Y_intermediate_2[2,1] = (5*4) + (6*1) + (7*7) + (8*4) = 20 + 6 + 49 + 32 = 107
- Y_intermediate_2[2,2] = (5*5) + (6*2) + (7*8) + (8*5) = 25 + 12 + 56 + 40 = 133
- Y_intermediate_2[2,3] = (5*6) + (6*3) + (7*9) + (8*6) = 30 + 18 + 63 + 48 = 159
So, the result of the matrix multiplication on GPU 2 is:

Y_intermediate_2 = [[43, 53, 63],
                    [107, 133, 159]]
Apply GeLU Activation Function
Next, the Gaussian Error Linear Unit (GeLU) activation function is applied element-wise to the intermediate results on each GPU. The GeLU function is defined as:

GeLU(x) = x * Φ(x)

where Φ(x) is the Cumulative Distribution Function (CDF) of the standard normal distribution. A common and precise formula is:

GeLU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
For simplicity in this manual example, let's assume GeLU(x) ≈ x for large positive x and GeLU(x) ≈ 0 for large negative x. Since all our numbers are large and positive, we'll approximate GeLU(x) = x. In a real scenario, the exact formula would be applied.
(Note: The erf function produces an "S"-shaped sigmoid curve that passes through the origin. As the input x gets very large and positive, erf(x) gets very close to 1. As the input gets very large and negative, erf(x) gets very close to -1.)
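To see how close the approximation is for the values in this example, here is a minimal sketch comparing the exact GeLU (computed via torch.erf) against the identity approximation used above; F.gelu is PyTorch's built-in version, shown only for reference.

import math
import torch
import torch.nn.functional as F

def gelu_exact(x):
    # GeLU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

y_intermediate_1 = torch.tensor([[31., 41., 51.], [83., 109., 135.]])

print(gelu_exact(y_intermediate_1))   # effectively identical to the input
print(F.gelu(y_intermediate_1))       # PyTorch's built-in GeLU agrees
print(torch.allclose(gelu_exact(y_intermediate_1), y_intermediate_1))  # True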
So, the outputs on each GPU are:

Y1 = [[31, 41, 51],
      [83, 109, 135]]   (on GPU 1)

Y2 = [[43, 53, 63],
      [107, 133, 159]]   (on GPU 2)

These two matrices, Y1 and Y2, represent the sharded output of the first layer. They remain on their respective GPUs and become the input for the next step.
Step 3: Forward Pass for Z = Dropout(Y * B)
This operation is parallelized using the row-sharded matrix B. The inputs Y1 and Y2 are already on the correct GPUs.
On GPU 1: Calculate Z_partial_1 = Y1 * B1
We multiply the local activation Y1 with the local weight shard B1.
Calculating each element:
- Z_partial_1[1,1] = (31*6) + (41*2) + (51*9) = 186 + 82 + 459 = 727
- Z_partial_1[1,2] = (31*5) + (41*1) + (51*8) = 155 + 41 + 408 = 604
- Z_partial_1[1,3] = (31*4) + (41*7) + (51*6) = 124 + 287 + 306 = 717
- Z_partial_1[1,4] = (31*3) + (41*8) + (51*5) = 93 + 328 + 255 = 676
- Z_partial_1[2,1] = (83*6) + (109*2) + (135*9) = 498 + 218 + 1215 = 1931
- Z_partial_1[2,2] = (83*5) + (109*1) + (135*8) = 415 + 109 + 1080 = 1604
- Z_partial_1[2,3] = (83*4) + (109*7) + (135*6) = 332 + 763 + 810 = 1905
- Z_partial_1[2,4] = (83*3) + (109*8) + (135*5) = 249 + 872 + 675 = 1796
Result on GPU 1:

Z_partial_1 = [[727, 604, 717, 676],
               [1931, 1604, 1905, 1796]]
On GPU 2: Calculate Z_partial_2 = Y2 * B2
Simultaneously, on GPU 2:
Calculating each element:
- Z_partial_2[1,1] = (43*4) + (53*1) + (63*5) = 172 + 53 + 315 = 540
- Z_partial_2[1,2] = (43*3) + (53*2) + (63*6) = 129 + 106 + 378 = 613
- Z_partial_2[1,3] = (43*2) + (53*3) + (63*7) = 86 + 159 + 441 = 686
- Z_partial_2[1,4] = (43*1) + (53*4) + (63*8) = 43 + 212 + 504 = 759
- Z_partial_2[2,1] = (107*4) + (133*1) + (159*5) = 428 + 133 + 795 = 1356
- Z_partial_2[2,2] = (107*3) + (133*2) + (159*6) = 321 + 266 + 954 = 1541
- Z_partial_2[2,3] = (107*2) + (133*3) + (159*7) = 214 + 399 + 1113 = 1726
- Z_partial_2[2,4] = (107*1) + (133*4) + (159*8) = 107 + 532 + 1272 = 1911
Result on GPU 2:

Z_partial_2 = [[540, 613, 686, 759],
               [1356, 1541, 1726, 1911]]
All-Reduce Communication Step
The logic behind splitting B by rows is that the final output Z is the sum of the partial results from each GPU. This requires a communication step called an All-Reduce. Each GPU sends its partial result to the others and receives their partial results, summing them all together.
Adding the partial results element-wise:

Z_unactivated = Z_partial_1 + Z_partial_2
              = [[1267, 1217, 1403, 1435],
                 [3287, 3145, 3631, 3707]]

After the All-Reduce, the complete Z_unactivated matrix is present on both GPUs.
Apply Dropout
Dropout is a regularization technique used during training to prevent overfitting. It randomly sets a fraction of the output elements to zero at each update step.
Let's assume a dropout rate of 50%. We create a "dropout mask" with the same dimensions as Z_unactivated, where each position is randomly 0 or 1.
Example dropout mask:

mask = [[1, 0, 0, 1],
        [0, 1, 1, 0]]

We multiply our Z_unactivated by this mask:

Z_masked = [[1267, 0, 0, 1435],
            [0, 3145, 3631, 0]]

Finally, to ensure the expected sum of outputs remains the same, the remaining non-zero elements are scaled up by a factor of 1 / (1 - dropout_rate). In our case, this is 1 / (1 - 0.5) = 2. This is called "inverted dropout".

Z = Z_masked * 2 = [[2534, 0, 0, 2870],
                    [0, 6290, 7262, 0]]

This final matrix Z is the output of our tensor-parallel multi-layer perceptron block.
Summary of the forward pass
- Start: We have an input X and two GPUs.
- Layer 1 (Column Parallelism):
  - The weight matrix A is split into columns (A1, A2).
  - Each GPU calculates a part of the output: X * A1 on GPU 1 and X * A2 on GPU 2.
  - An activation function (GeLU) is applied independently on each GPU to these partial results (Y1, Y2). No communication is needed.
- Layer 2 (Row Parallelism):
  - The weight matrix B is split into rows (B1, B2).
  - Each GPU multiplies its local activation from the previous step with its local weight shard: Y1 * B1 on GPU 1 and Y2 * B2 on GPU 2.
  - An All-Reduce operation sums the partial results from all GPUs. This is the critical communication step.
  - The final operation (Dropout) is applied to the complete, summed tensor, resulting in the final output Z.
Why split Matrix A by columns, Matrix B by rows
The goal is to sequence the operations so that the expensive All-Reduce step (where GPUs must synchronize and sum their results) happens as infrequently as possible. To understand this, let's define the two operations and their communication costs.
1. Column Parallelism (What we do to A)
When we split a weight matrix A by columns, the input X is broadcast to all GPUs. Each GPU then calculates a vertical "slice" of the output.
- Mathematics: Y = X * [A1, A2] = [X * A1, X * A2] = [Y1, Y2]
- Result: The output Y is naturally sharded (split) across the GPUs. Y1 is on GPU 1 and Y2 is on GPU 2.
- Communication Cost (Forward Pass): Zero. No communication is needed to produce the sharded output Y. Each GPU works independently. The output is left "as is" on each device, ready for the next step.
2. Row Parallelism (What we do to B)
When we split a weight matrix B by rows, the input must already be sharded in a corresponding way (which is exactly what Column Parallelism just produced!). Each GPU multiplies its local input shard with its local weight shard.
- Mathematics: Z = Y * B = [Y1, Y2] * [B1; B2] = (Y1 * B1) + (Y2 * B2) (Note: ; denotes stacking rows)
- Result: Each GPU calculates a partial result. To get the final, correct output Z, these partial results must be summed together.
- Communication Cost (Forward Pass): One All-Reduce operation. This is the expensive step where Z_partial_1 and Z_partial_2 are added across all GPUs. A small sketch verifying both decompositions follows this list.
Analyzing the Standard Column -> Row Pattern
Now, let's see how these two pieces fit together in our MLP block:
- First Layer: Y = GeLU(X * A) (Column Parallelism)
  - We split A by columns.
  - The multiplication X * A produces a sharded activation [Y1, Y2] without any need for communication.
  - Forward Communication Cost: 0
- Second Layer: Z = Dropout(Y * B) (Row Parallelism)
  - We split B by rows.
  - The sharded activation [Y1, Y2] from the previous step is the perfect input for this layer.
  - The multiplication Y * B produces partial results Z_partial_1 and Z_partial_2.
  - To get the final output Z, we must perform an All-Reduce to sum the partial results.
  - Forward Communication Cost: 1 All-Reduce
By sequencing the operations this way (Column -> Row), we have successfully computed the output for a full two-layer block with only one communication step.
What Happens If We Switch? (Row -> Column)
Let's imagine we tried to do it the other way around.
- First Layer: Y = GeLU(X * A) (Row Parallelism)
  - We split A by rows (A1, A2).
  - For the math X * A to work, we would also have to split the input X by columns (X1, X2).
  - The operation becomes (X1 * A1) + (X2 * A2).
  - This produces partial results which must be summed.
  - Forward Communication Cost: 1 All-Reduce. The output Y is now a complete (un-sharded) tensor, identical on all GPUs.
- Second Layer: Z = Dropout(Y * B) (Column Parallelism)
  - We split B by columns (B1, B2).
  - The input Y is a complete tensor from the previous step.
  - The operation Y * [B1, B2] produces a sharded output [Z1, Z2] without communication.
  - Forward Communication Cost: 0
Why column then row?
As you can see, both patterns (Column -> Row and Row -> Column) result in the exact same amount of communication: one All-Reduce per two-layer block.
So why is the Column -> Row pattern the standard convention used in frameworks like Megatron-LM?
The reason is modularity and efficiency in the computational graph. The standard Transformer block consists of an Attention block followed by an MLP block. The Column -> Row pattern creates a perfect compositional unit:
- The MLP block takes a complete tensor as input (like X).
- It performs Column Parallel -> Row Parallel.
- It outputs a complete tensor (Z, after the All-Reduce).
This complete tensor Z can then be fed directly into the next Attention block, which will also start with a column-parallel operation. This creates a clean, predictable, and efficient computational flow where the "sharded" nature of the activations is kept internal to the block, and communication is neatly packaged.
In short, the Column -> Row sequence is a deliberate design pattern that delays the necessary communication to the very last step within a block, creating a modular and efficient structure for building large neural networks.
Backward pass (backpropagation)
The goal of the backward pass is to calculate the gradients of the loss function with respect to all the learnable parameters (the weights A and B) and the input (X). These gradients tell us how to adjust the weights to reduce the loss.
The process starts with the gradient of the loss with respect to the final output, which we'll call dL/dZ or simply dZ. This initial gradient propagates backward through the network. For this example, we'll assume the process has already begun and the initial gradient dZ is a 2x4 matrix of ones.
Initial Gradient (assumed):

dZ = [[1, 1, 1, 1],
      [1, 1, 1, 1]]
Step 1: Backward Pass through Dropout
The first step is to backpropagate through the Dropout layer. The gradient is only allowed to pass through the neurons that were not "dropped out" during the forward pass. This means we apply the same dropout mask and scaling factor.
Mathematics
The gradient with respect to the pre-dropout output (Z_unactivated) is calculated by element-wise multiplying the incoming gradient dZ with the dropout mask and scaling it by the same factor used in the forward pass.
Calculation
From our forward pass, our mask and scaling factor were:

mask = [[1, 0, 0, 1],
        [0, 1, 1, 0]], Scale = 2

The gradient dZ is then passed through this inverted dropout layer:

dZ_unactivated = dZ * mask * 2 = [[2, 0, 0, 2],
                                  [0, 2, 2, 0]]
This gradient, dZ_unactivated, is now available on both GPUs, as the All-Reduce in the forward pass made the output Z_unactivated identical on both.
Step 2: Backward Pass for Z = Dropout(Y * B)
Now we backpropagate through the second linear layer. This involves calculating the gradients with respect to the weights (dB) and the activations from the previous layer (dY).
Mathematics for Gradients dY and dB
Given the operation Z_unactivated = Y * B, the gradients are:
- Gradient w.r.t. Y: dY = dZ_unactivated * B^T
- Gradient w.r.t. B: dB = Y^T * dZ_unactivated
In our tensor parallel setup (row parallelism), Y = [Y1, Y2] and B is split into B1 and B2. The forward pass involved local multiplications (Y1*B1, Y2*B2) followed by an All-Reduce. The backward pass reverses this.
- For dY (Activation Gradients): The gradient dZ_unactivated is multiplied by the transpose of the local weight shard. This is a local computation on each GPU and requires no communication.
  - On GPU 1: dY1 = dZ_unactivated * B1^T
  - On GPU 2: dY2 = dZ_unactivated * B2^T
- For dB (Weight Gradients): The transpose of the local activation Y is multiplied by the incoming gradient.
  - On GPU 1: dB1 = Y1^T * dZ_unactivated
  - On GPU 2: dB2 = Y2^T * dZ_unactivated
Calculation for dY
On GPU 1: Calculate dY1 = dZ_unactivated * B1^T
On GPU 2: Calculate dY2 = dZ_unactivated * B2^T
Calculation for dB
On GPU 1: Calculate dB1 = Y1^T * dZ_unactivated
On GPU 2: Calculate dB2 = Y2^T * dZ_unactivated
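Since these result matrices are easiest to check by machine, here is a small single-device sketch that computes the four local gradients from the Y1, Y2, B1, B2 values of the forward pass and the dZ_unactivated gradient from Step 1. The variable names are simply this article's labels, and dB1/dB2 can be compared against the autograd verification at the end of the article.

import torch

dZ_unactivated = torch.tensor([[2., 0., 0., 2.],
                               [0., 2., 2., 0.]])
Y1 = torch.tensor([[31., 41., 51.], [83., 109., 135.]])
Y2 = torch.tensor([[43., 53., 63.], [107., 133., 159.]])
B1 = torch.tensor([[6., 5., 4., 3.], [2., 1., 7., 8.], [9., 8., 6., 5.]])
B2 = torch.tensor([[4., 3., 2., 1.], [1., 2., 3., 4.], [5., 6., 7., 8.]])

# Activation gradients (local, no communication needed):
dY1 = dZ_unactivated @ B1.T   # on "GPU 1"
dY2 = dZ_unactivated @ B2.T   # on "GPU 2"

# Weight gradients (also local):
dB1 = Y1.T @ dZ_unactivated
dB2 = Y2.T @ dZ_unactivated

print(f"dY1:\n{dY1}\ndY2:\n{dY2}\ndB1:\n{dB1}\ndB2:\n{dB2}")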
Step 3: Backward Pass through GeLU Activation
Next, the gradients dY1 and dY2 are passed backward through the GeLU activation function.
Mathematics
The chain rule dictates that we multiply the incoming gradient by the derivative of the activation function evaluated at its original input.
The derivative of GeLU is GeLU'(x) = Φ(x) + x * φ(x), where φ(x) is the standard normal probability density function. As in the forward pass, for the large positive numbers in our example, this derivative is very close to 1. We will use this approximation for simplicity.
Calculation
We simply pass the gradients through, as multiplying by 1 does not change them.
On GPU 1: dY_intermediate_1 ≈ dY1
On GPU 2: dY_intermediate_2 ≈ dY2
Step 4: Backward Pass for Y = GeLU(X * A)
Finally, we backpropagate through the first linear layer. This computes the gradients dA (for the weights) and dX (for the input).
Mathematics for Gradients dX and dA
Given Y_intermediate = X * A, the gradients are:
- Gradient w.r.t. X: dX = dY_intermediate * A^T
- Gradient w.r.t. A: dA = X^T * dY_intermediate
In our tensor parallel setup (column parallelism), the input X was broadcast, and the output Y_intermediate was sharded (Y_intermediate_1, Y_intermediate_2).
- For dA (Weight Gradients): The calculation is local.
  - On GPU 1: dA1 = X^T * dY_intermediate_1
  - On GPU 2: dA2 = X^T * dY_intermediate_2
- For dX (Input Gradients): Each GPU calculates a partial gradient for dX. These partial gradients must be summed together using an All-Reduce operation.
  - On GPU 1: dX_partial_1 = dY_intermediate_1 * A1^T
  - On GPU 2: dX_partial_2 = dY_intermediate_2 * A2^T
  - Final Gradient: dX = dX_partial_1 + dX_partial_2 (summed via All-Reduce)
Calculation for dA
On GPU 1: Calculate dA1 = X^T * dY_intermediate_1
On GPU 2: Calculate dA2 = X^T * dY_intermediate_2
Calculation for dX (with All-Reduce)
On GPU 1: Calculate dX_partial_1 = dY_intermediate_1 * A1^T
On GPU 2: Calculate dX_partial_2 = dY_intermediate_2 * A2^T
All-Reduce Communication Step
Finally, the partial gradients for X are summed across GPUs.
This final dX is the gradient with respect to the input of the MLP block, ready to be passed further backward to the previous layer in the network. The weight gradients dA1, dA2, dB1, and dB2 are used by the optimizer to update the model's weights on each respective GPU.
The Duality of Communication in Forward vs. Backward Pass
A fundamental principle of backpropagation in tensor parallelism is that the communication patterns are reversed.
- If a layer requires an All-Reduce in the forward pass, its corresponding backward pass for the input gradient requires no communication (Identity).
- If a layer requires no communication (Identity) in the forward pass, its corresponding backward pass for the input gradient requires an All-Reduce.
Here is a simple table summarizing the communication for the input gradients (dX and dY):
Layer Type (Forward) | Forward Pass Communication (f) | Backward Pass Communication (g)
---|---|---
Column Parallelism | Identity (No communication) | All-Reduce (Sum partial results)
Row Parallelism | All-Reduce (Sum partial results) | Identity (No communication)
(Note: The gradients for the weights, dA and dB, are always local calculations and never require an All-Reduce, so we can focus on the activation gradients.)
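In Megatron-style implementations this duality is expressed as a pair of conjugate autograd operators, often referred to as f and g. Below is a single-process sketch of that idea using torch.autograd.Function; the dist.all_reduce call only fires when a process group has actually been initialized, so the file also runs stand-alone on CPU. The class names here are illustrative, not the names used by any particular library.

import torch
import torch.distributed as dist

def _maybe_all_reduce(t):
    # In a real tensor-parallel job this is a mandatory collective;
    # here it is skipped when no process group exists (single-process run).
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t

class CopyToTensorParallel(torch.autograd.Function):
    """'f': identity in the forward pass, All-Reduce in the backward pass."""
    @staticmethod
    def forward(ctx, x):
        return x
    @staticmethod
    def backward(ctx, grad_output):
        return _maybe_all_reduce(grad_output.clone())

class ReduceFromTensorParallel(torch.autograd.Function):
    """'g': All-Reduce in the forward pass, identity in the backward pass."""
    @staticmethod
    def forward(ctx, x):
        return _maybe_all_reduce(x.clone())
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

# Illustrative placement inside a column->row MLP block (pseudocode):
#   y = gelu(CopyToTensorParallel.apply(x) @ A_shard)         # f before layer 1
#   z = dropout(ReduceFromTensorParallel.apply(y @ B_shard))  # g after layer 2
x = torch.ones(2, 4, requires_grad=True)
out = ReduceFromTensorParallel.apply(CopyToTensorParallel.apply(x))
out.sum().backward()
print(x.grad)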
Backpropagation Through the Standard Column -> Row Pattern
Now, let's trace the backward pass through our standard two-layer block, keeping the table above in mind. The process starts with dZ, the gradient from the layer above. Because the forward pass ended with an All-Reduce, Z was a complete tensor on all GPUs, which means dZ is also a complete tensor.
- Backprop through Layer 2: Z = Dropout(Y * B) (Row Parallelism)
  - Forward Pass required: All-Reduce.
  - Backward Pass requires: Identity (No communication).
  - Calculation (dY = dZ * B^T): The complete gradient dZ is multiplied by the local weight shard B^T on each GPU. This directly produces the correct sharded activation gradient dY ([dY1, dY2]) without any need for communication.
- Backprop through Layer 1: Y = GeLU(X * A) (Column Parallelism)
  - Forward Pass required: Identity.
  - Backward Pass requires: All-Reduce.
  - Calculation (dX = dY * A^T): The sharded gradient dY from the previous step is multiplied by the local weight shard A^T. This produces a partial gradient for the input, dX_partial. To get the final, complete gradient dX, these partial results must be summed across all GPUs. This is the All-Reduce step.
As you can see, the backward pass for our Column -> Row block has exactly one All-Reduce operation, just like the forward pass.
What if we used Row -> Column?
The result would be perfectly symmetric:
- Backprop through Layer 2 (Column Parallelism): The incoming gradient dZ would be sharded. This step would require an All-Reduce to produce a complete gradient dY.
- Backprop through Layer 1 (Row Parallelism): The complete gradient dY would be the input. This step would require no communication (Identity) to produce the sharded input gradient dX.
The total communication remains one All-Reduce in the forward pass and one in the backward pass.
The Preference is for Modularity, Not Cost
The total cost is the same. So why do we prefer Column -> Row?
The reason is that this pattern creates a perfectly self-contained, modular block.
- Input: The block takes a complete tensor (X).
- Internal State: It creates a sharded tensor (Y) internally.
- Output: It performs an All-Reduce at the end and outputs a complete tensor (Z).
The backward pass mirrors this perfectly:
- Input: It takes a complete gradient (dZ).
- Internal State: It creates a sharded gradient (dY) internally.
- Output: It performs an All-Reduce at the end and outputs a complete gradient (dX).
This means you can stack these Transformer blocks one after another without worrying about whether the tensor you're passing is complete or sharded. Each block handles its own internal sharding and communication, always presenting a clean, complete tensor at its boundaries. This makes the overall architecture of a large model much simpler to design, implement, and debug.
Communication cost of tensor parallelism
In our walkthrough, communication between GPUs did not happen after every mathematical operation. It only occurred at very specific points where partial results needed to be combined to form a complete tensor. Let's pinpoint them:
- Forward Pass Communication: This happened at the end of Step 3. To calculate the final Z_unactivated, we needed to sum the partial results from each GPU. This summation is an all-reduce operation. GPU 1 sends its Z_partial_1 and receives Z_partial_2; GPU 2 sends Z_partial_2 and receives Z_partial_1. Both then compute the sum.
- Backward Pass Communication: This happened at the end of Step 4. To calculate the final input gradient dX, we needed to sum the partial gradients from each GPU. This summation is also an all-reduce operation.
For a single pass through this two-layer block, we perform two full all-reduce operations.
Analyzing the Communication Volume
The formula for the communication volume of tensor parallelism, per layer, is:

Communication volume per layer = 8 * b * s * h * (n_devices - 1) / n_devices
Source: Stanford CS 336 Language Modeling from Scratch, Lecture 7: Parallelism 1, 1 hour, 1 min 8 sec
- b: the batch size. In our example, the input matrix X had 2 rows, so b = 2.
- s: the sequence length. In a simple MLP example, we can consider this to be 1, so s = 1.
- h: the hidden size of the tensor being communicated. The tensor we communicated in the forward pass was Z_partial, which had a shape of (2, 4). The size of the last dimension is 4, so h = 4.
- n_devices: the number of GPUs, which is 2.
In short, bsh represents the total number of elements in the activation tensor being communicated.
Volume of data in one all-reduce:
In the forward pass all-reduce, the tensor being communicated is Z_partial. The number of elements in this tensor is:

b * s * h = 2 * 1 * 4 = 8 elements

The all-reduce operation is a collective communication where each GPU sends its 8 elements and receives 8 elements from its peer.
The term (n_devices - 1) / n_devices in the formula characterizes the volume of a standard ring all-reduce algorithm. For our 2 GPUs, this factor is (2 - 1) / 2 = 1/2. This signifies that each GPU sends and receives data chunks in a way that scales with the number of devices.
The factor of 8 in the formula often accounts for two things: 4 bytes for a standard 32-bit floating-point number, and a factor of 2 because communication happens in both the forward and backward passes.
Why the communication cost of tensor parallelism is high
The communication cost of tensor parallelism is high due to the combination of volume and frequency:
High-Cost Operation: The all-reduce is a "collective" operation that is much more complex and latency-sensitive than a simple point-to-point send/receive. It requires synchronization across all GPUs in the group.
Large Data Volume: The amount of data communicated in each all-reduce is the full activation tensor (bsh), which can be very large in modern LLMs.
), which can be very large in modern LLMs.High Frequency: This expensive, high-volume operation is not performed once, but twice for every single transformer block in the model during a full training step (forward + backward pass).
Therefore, while tensor parallelism is excellent at keeping all GPUs utilized (eliminating the "bubble" of pipeline parallelism), it comes at the cost of constant, high-volume communication across all participating devices. This is why the general recommendation is to use tensor parallelism when we have low-latency, high-bandwidth interconnects like NVIDIA's NVLink, which are specifically designed to handle this intense communication pattern efficiently.
Code
For those who prefer following the code, the walkthrough for both the forward and backward passes is replicated below in Python (PyTorch).
Setup
import torch
import torch.nn.functional as F
# --------------------------------------------------------------------------
# 1. SETUP: Define tensors from the manual example
# --------------------------------------------------------------------------
# Set requires_grad=True to track gradients
X = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8]], dtype=torch.float32, requires_grad=True)
A = torch.tensor([[1, 2, 3, 4, 5, 6],
[7, 8, 9, 1, 2, 3],
[4, 5, 6, 7, 8, 9],
[1, 2, 3, 4, 5, 6]], dtype=torch.float32, requires_grad=True)
B = torch.tensor([[6, 5, 4, 3],
[2, 1, 7, 8],
[9, 8, 6, 5],
[4, 3, 2, 1],
[1, 2, 3, 4],
[5, 6, 7, 8]], dtype=torch.float32, requires_grad=True)
# NOTE: For verification, we use an identity function to match the manual
# calculation's approximation of GeLU(x) ≈ x. In a real model, you would
# use F.gelu(x).
gelu = lambda x: x
# Define the fixed dropout mask from the manual example
dropout_mask = torch.tensor([[1, 0, 0, 1],
[0, 1, 1, 0]], dtype=torch.float32)
dropout_rate = 0.5
dropout_scale = 1 / (1 - dropout_rate)
print("--- Initial Tensors ---")
print(f"X:\n{X}\n")
print(f"A:\n{A}\n")
print(f"B:\n{B}\n")
Output
--- Initial Tensors ---
X:
tensor([[1., 2., 3., 4.],
[5., 6., 7., 8.]], requires_grad=True)
A:
tensor([[1., 2., 3., 4., 5., 6.],
[7., 8., 9., 1., 2., 3.],
[4., 5., 6., 7., 8., 9.],
[1., 2., 3., 4., 5., 6.]], requires_grad=True)
B:
tensor([[6., 5., 4., 3.],
[2., 1., 7., 8.],
[9., 8., 6., 5.],
[4., 3., 2., 1.],
[1., 2., 3., 4.],
[5., 6., 7., 8.]], requires_grad=True)
Forward pass
# --------------------------------------------------------------------------
# 2. FORWARD PASS
# --------------------------------------------------------------------------
print("--- FORWARD PASS ---")
# --- Simulate Tensor Parallelism on 2 GPUs ---
# Split A by columns (Column Parallelism)
# "GPU 1" gets A1, "GPU 2" gets A2
A1 = A[:, :3]
A2 = A[:, 3:]
# Split B by rows (Row Parallelism)
# "GPU 1" gets B1, "GPU 2" gets B2
B1 = B[:3, :]
B2 = B[3:, :]
# --- Layer 1: Y = GeLU(X * A) ---
# Operation is local to each GPU, no communication needed.
# "GPU 1" computes its part
Y_intermediate_1 = X @ A1
Y1 = gelu(Y_intermediate_1)
# "GPU 2" computes its part
Y_intermediate_2 = X @ A2
Y2 = gelu(Y_intermediate_2)
print(f"Y1 (on GPU 1):\n{Y1}\n")
print(f"Y2 (on GPU 2):\n{Y2}\n")
# --- Layer 2: Z = Dropout(Y * B) ---
# Local matrix multiplication on each GPU
# Note: Y is implicitly [Y1, Y2]
Z_partial_1 = Y1 @ B1
Z_partial_2 = Y2 @ B2
print(f"Z_partial_1 (on GPU 1):\n{Z_partial_1}\n")
print(f"Z_partial_2 (on GPU 2):\n{Z_partial_2}\n")
# *** All-Reduce Communication Step ***
# The partial results are summed across GPUs.
Z_unactivated = Z_partial_1 + Z_partial_2
print(f"Z_unactivated (after All-Reduce):\n{Z_unactivated}\n")
# Apply Dropout
Z_masked = Z_unactivated * dropout_mask
Z = Z_masked * dropout_scale
print(f"Final Output Z:\n{Z}\n")
Output
--- FORWARD PASS ---
Y1 (on GPU 1):
tensor([[ 31., 41., 51.],
[ 83., 109., 135.]], grad_fn=<MmBackward0>)
Y2 (on GPU 2):
tensor([[ 43., 53., 63.],
[107., 133., 159.]], grad_fn=<MmBackward0>)
Z_partial_1 (on GPU 1):
tensor([[ 727., 604., 717., 676.],
[1931., 1604., 1905., 1796.]], grad_fn=<MmBackward0>)
Z_partial_2 (on GPU 2):
tensor([[ 540., 613., 686., 759.],
[1356., 1541., 1726., 1911.]], grad_fn=<MmBackward0>)
Z_unactivated (after All-Reduce):
tensor([[1267., 1217., 1403., 1435.],
[3287., 3145., 3631., 3707.]], grad_fn=<AddBackward0>)
Final Output Z:
tensor([[2534., 0., 0., 2870.],
[ 0., 6290., 7262., 0.]], grad_fn=<MulBackward0>)
Backward pass
# --------------------------------------------------------------------------
# 3. BACKWARD PASS
# --------------------------------------------------------------------------
print("--- BACKWARD PASS ---")
# Define a dummy loss. Z.sum() creates an initial gradient dZ of all ones.
loss = Z.sum()
print(f"Loss (for backprop):\n{loss}\n")
# Automatically compute gradients for all tensors with requires_grad=True
loss.backward()
Output
--- BACKWARD PASS ---
Loss (for backprop):
18956.0
Verification
# --------------------------------------------------------------------------
# 4. VERIFICATION: Check gradients against manual calculations
# --------------------------------------------------------------------------
print("--- VERIFYING GRADIENTS ---\n")
# Gradient for X (dX)
# This required an All-Reduce in the backward pass
print(f"Gradient for X (dX):\n{X.grad}\n")
# Gradients for A (dA1 and dA2)
# A.grad contains the gradients for the entire A matrix. We can verify
# by slicing it into the parts corresponding to A1 and A2.
print(f"Gradient for A1 (dA1):\n{A.grad[:, :3]}\n")
print(f"Gradient for A2 (dA2):\n{A.grad[:, 3:]}\n")
# Gradients for B (dB1 and dB2)
# Similarly, B.grad contains the full gradient.
print(f"Gradient for B1 (dB1):\n{B.grad[:3, :]}\n")
print(f"Gradient for B2 (dB2):\n{B.grad[3:, :]}\n")
Output
--- VERIFYING GRADIENTS ---
Gradient for X (dX):
tensor([[388., 646., 724., 388.],
[380., 614., 704., 380.]])
Gradient for A1 (dA1):
tensor([[108., 100., 168.],
[144., 136., 224.],
[180., 172., 280.],
[216., 208., 336.]])
Gradient for A2 (dA2):
tensor([[ 60., 60., 156.],
[ 80., 80., 208.],
[100., 100., 260.],
[120., 120., 312.]])
Gradient for B1 (dB1):
tensor([[ 62., 166., 166., 62.],
[ 82., 218., 218., 82.],
[102., 270., 270., 102.]])
Gradient for B2 (dB2):
tensor([[ 86., 214., 214., 86.],
[106., 266., 266., 106.],
[126., 318., 318., 126.]])