Lewis Won

Tensor parallelism by hand


Motivation

Tensor parallelism is widely used in training and deploying massive neural networks such as LLMs across multiple GPUs. While libraries such as DeepSpeed provide out-of-the-box solutions for tensor parallelism, understanding how to implement tensor parallelism by hand may help with:

  • Building intuition: instead of just knowing that tensor parallelism splits tensors, you will understand why splitting by "columns" in one layer and "rows" in the next is essential for the math to work out.
  • Understanding performance bottlenecks: You will gain an understanding of why tensor parallelism is communication-intensive and why it is generally recommended to run over high-speed interconnects such as NVLink.
  • Debugging: An intuitive mental model of tensor parallelism may be useful for debugging when distributed training jobs fail and the error messages are cryptic.

This article can also serve as a complementary study guide for Stanford's CS336 Lecture 7 on Parallelism. It was written with the assistance of Gemini 2.5 Pro.


What is tensor parallelism

When a model becomes too massive to fit into a single GPU's memory, tensor parallelism offers a way to split up the workload. Tensor parallelism accomplishes two goals simultaneously:

  1. Distributes Memory Load: It combines the memory of several GPUs to store a single, massive layer that would be too large for any one device.
  2. Parallelizes Computation: Each GPU performs a smaller portion of the matrix multiplication using its local shard of the weight matrix. This allows the GPUs to work on the same forward or backward pass of a single layer at the same time, speeding up the computation.

When do we implement tensor parallelism

Tensor parallelism comes at a significant cost of communication overhead. This article will demonstrate why tensor parallelism requires a high-volume, synchronous communication step (All-Reduce) twice for every single block in the model—once during the forward pass and once during the backward pass. This creates a fundamental trade-off:

  • Pipeline Parallelism: In pipeline parallelism, you split groups of layers (stages) across GPUs. Communication is less frequent, happening only at the boundaries between stages. However, pipeline parallelism introduces a "bubble" of idle time at the beginning and end of each batch as the pipeline fills and drains, reducing hardware utilization.
  • Tensor Parallelism: In tensor parallelism, you split the work within each layer. This helps to keep all GPUs perfectly synchronized and busy, completely eliminating the pipeline "bubble." But the price for this high utilization is constant, layer-by-layer communication across all participating GPUs.

Tensor parallelism's more intense communication pattern, compared with pipeline parallelism, leads to a rule of thumb for designing large-scale training setups:

  • Use Tensor Parallelism for Intra-Node Scaling: When you have a group of GPUs (e.g., the 8 A100s in a DGX node) connected with a high-speed backplane, tensor parallelism is the ideal way to scale up to handle massive layers.
  • Use Pipeline and Data Parallelism for Inter-Node Scaling: To scale beyond a single node, you typically combine tensor parallelism with pipeline parallelism, where each stage of the pipeline is a tensor-parallel group of GPUs. This minimizes the slow communication across nodes to only what is necessary between pipeline stages.
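As a toy illustration of this rule of thumb, the sketch below maps a hypothetical cluster of 2 nodes with 8 GPUs each onto parallelism groups: tensor parallelism inside a node, one pipeline stage per node. The node and GPU counts are assumptions chosen for illustration; real frameworks build these groups with torch.distributed process groups or a device mesh.

```python
# Hypothetical cluster: 2 nodes x 8 GPUs. Tensor parallelism spans the GPUs
# inside a node (fast intra-node links); each node forms one pipeline stage,
# so slower inter-node links are crossed only at stage boundaries.
N_NODES, GPUS_PER_NODE = 2, 8

for rank in range(N_NODES * GPUS_PER_NODE):
    pp_stage = rank // GPUS_PER_NODE   # pipeline-parallel stage (inter-node)
    tp_rank = rank % GPUS_PER_NODE     # tensor-parallel rank (intra-node)
    print(f"global rank {rank:2d} -> pipeline stage {pp_stage}, tensor-parallel rank {tp_rank}")
```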

Setup

We will walk through the forward and backward passes of a two-layer Multi-Layer Perceptron (MLP) block found in a Transformer. The block consists of two main operations:

  1. Y = GeLU(X * A)
  2. Z = Dropout(Y * B)

We'll assume we have two GPUs, GPU 1 and GPU 2.

[Diagram: tensor parallelism of the two-layer MLP block across two GPUs]

Source: Stanford CS 336 Language Modeling from Scratch, Lecture 7: Parallelism 1, 58 min 12 sec
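For reference, here is what the unsharded block computes on a single device: a minimal sketch using random placeholder weights (the concrete matrices used in the walkthrough are defined in Step 1 below).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(2, 4)   # input
A = torch.randn(4, 6)   # first weight matrix
B = torch.randn(6, 4)   # second weight matrix

Y = F.gelu(X @ A)                            # 1. Y = GeLU(X * A)
Z = F.dropout(Y @ B, p=0.5, training=True)   # 2. Z = Dropout(Y * B)
print(Z.shape)  # torch.Size([2, 4])
```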


Forward pass

We begin with the forward pass; a later section walks through backpropagation.

The forward pass in a neural network is the process of taking an input and propagating it through the layers of the model to generate a final output. In the context of tensor parallelism, this process is modified because the model's large weight matrices are too big for a single GPU and are instead sharded, or split, across multiple GPUs.

To execute a matrix multiplication, each GPU performs a smaller calculation in parallel using its local shard of the weights. This often follows a specific pattern where one layer's weights are split by columns (requiring no communication after the multiplication) and the next layer's weights are split by rows. This row-parallel approach produces partial results on each GPU that must then be aggregated through a collective communication operation, such as an All-Reduce, to ensure the final activation tensor is mathematically equivalent to the result that would have been computed on a single, monolithic GPU.

Step 1: Initial Setup - Defining and Splitting Matrices

First, let's define our input matrix X and the weight matrices A and B for our two linear layers. We will use small, simple numbers for clarity.

Let our input X be a 2x4 matrix:

X = \begin{pmatrix} 1 & 2 & 3 & 4 \newline 5 & 6 & 7 & 8 \end{pmatrix}

Let the first weight matrix A be a 4x6 matrix:

A = \begin{pmatrix} 1 & 2 & 3 & 4 & 5 & 6 \newline 7 & 8 & 9 & 1 & 2 & 3 \newline 4 & 5 & 6 & 7 & 8 & 9 \newline 1 & 2 & 3 & 4 & 5 & 6 \end{pmatrix}

Let the second weight matrix B be a 6x4 matrix:

B = \begin{pmatrix} 6 & 5 & 4 & 3 \newline 2 & 1 & 7 & 8 \newline 9 & 8 & 6 & 5 \newline 4 & 3 & 2 & 1 \newline 1 & 2 & 3 & 4 \newline 5 & 6 & 7 & 8 \end{pmatrix}

Now, we split these matrices across our two GPUs.

  • Matrix A is split vertically (by columns). GPU 1 gets the first 3 columns (A1), and GPU 2 gets the last 3 columns (A2). This is known as column parallelism.

A_1 = \begin{pmatrix} 1 & 2 & 3 \newline 7 & 8 & 9 \newline 4 & 5 & 6 \newline 1 & 2 & 3 \end{pmatrix} (on GPU 1)

A_2 = \begin{pmatrix} 4 & 5 & 6 \newline 1 & 2 & 3 \newline 7 & 8 & 9 \newline 4 & 5 & 6 \end{pmatrix} (on GPU 2)

  • Matrix B is split horizontally (by rows). GPU 1 gets the first 3 rows (B1), and GPU 2 gets the last 3 rows (B2). This is known as row parallelism.

B_1 = \begin{pmatrix} 6 & 5 & 4 & 3 \newline 2 & 1 & 7 & 8 \newline 9 & 8 & 6 & 5 \end{pmatrix} (on GPU 1)

B_2 = \begin{pmatrix} 4 & 3 & 2 & 1 \newline 1 & 2 & 3 & 4 \newline 5 & 6 & 7 & 8 \end{pmatrix} (on GPU 2)

We split matrix A by columns and matrix B by rows so as to minimise the communication cost incurred by the all-reduce operation. The sketch below gives a quick numerical check of why these two split directions compose, and a later section explains the reasoning in more detail.
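The following sketch runs on one device in plain PyTorch (GeLU omitted here, since it acts element-wise on each shard): it verifies that the column shards of A simply concatenate, while the row shards of B produce partial results that must be summed, which is exactly what the all-reduce will do.

```python
import torch

X = torch.tensor([[1., 2., 3., 4.],
                  [5., 6., 7., 8.]])
A = torch.tensor([[1., 2., 3., 4., 5., 6.],
                  [7., 8., 9., 1., 2., 3.],
                  [4., 5., 6., 7., 8., 9.],
                  [1., 2., 3., 4., 5., 6.]])
B = torch.tensor([[6., 5., 4., 3.],
                  [2., 1., 7., 8.],
                  [9., 8., 6., 5.],
                  [4., 3., 2., 1.],
                  [1., 2., 3., 4.],
                  [5., 6., 7., 8.]])

A1, A2 = A[:, :3], A[:, 3:]   # column parallelism
B1, B2 = B[:3, :], B[3:, :]   # row parallelism

Y = X @ A
# Column shards: the outputs sit side by side -- no summation needed.
assert torch.equal(torch.cat([X @ A1, X @ A2], dim=1), Y)
# Row shards: each shard yields a partial result; the sum is what the
# all-reduce computes in a real multi-GPU run.
assert torch.equal(Y[:, :3] @ B1 + Y[:, 3:] @ B2, Y @ B)
```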

Step 2: Forward Pass for Y = GeLU(X * A)

This operation is parallelized using the column-sharded matrix A. The input X is broadcast to both GPUs.

On GPU 1: Calculate Y_intermediate_1 = X * A1

We perform matrix multiplication. The output will be a 2x3 matrix.

Y_{intermediate_1} = \begin{pmatrix} 1 & 2 & 3 & 4 \newline 5 & 6 & 7 & 8 \end{pmatrix} \times \begin{pmatrix} 1 & 2 & 3 \newline 7 & 8 & 9 \newline 4 & 5 & 6 \newline 1 & 2 & 3 \end{pmatrix}

Calculating each element:

  • Y_intermediate_1[1,1] = (1*1) + (2*7) + (3*4) + (4*1) = 1 + 14 + 12 + 4 = 31
  • Y_intermediate_1[1,2] = (1*2) + (2*8) + (3*5) + (4*2) = 2 + 16 + 15 + 8 = 41
  • Y_intermediate_1[1,3] = (1*3) + (2*9) + (3*6) + (4*3) = 3 + 18 + 18 + 12 = 51
  • Y_intermediate_1[2,1] = (5*1) + (6*7) + (7*4) + (8*1) = 5 + 42 + 28 + 8 = 83
  • Y_intermediate_1[2,2] = (5*2) + (6*8) + (7*5) + (8*2) = 10 + 48 + 35 + 16 = 109
  • Y_intermediate_1[2,3] = (5*3) + (6*9) + (7*6) + (8*3) = 15 + 54 + 42 + 24 = 135

So, the result of the matrix multiplication on GPU 1 is:

Y_{intermediate_1} = \begin{pmatrix} 31 & 41 & 51 \newline 83 & 109 & 135 \end{pmatrix}

On GPU 2: Calculate Y_intermediate_2 = X * A2

Similarly, we perform the multiplication on GPU 2.

Y_{intermediate_2} = \begin{pmatrix} 1 & 2 & 3 & 4 \newline 5 & 6 & 7 & 8 \end{pmatrix} \times \begin{pmatrix} 4 & 5 & 6 \newline 1 & 2 & 3 \newline 7 & 8 & 9 \newline 4 & 5 & 6 \end{pmatrix}

Calculating each element:

  • Y_intermediate_2[1,1] = (1*4) + (2*1) + (3*7) + (4*4) = 4 + 2 + 21 + 16 = 43
  • Y_intermediate_2[1,2] = (1*5) + (2*2) + (3*8) + (4*5) = 5 + 4 + 24 + 20 = 53
  • Y_intermediate_2[1,3] = (1*6) + (2*3) + (3*9) + (4*6) = 6 + 6 + 27 + 24 = 63
  • Y_intermediate_2[2,1] = (5*4) + (6*1) + (7*7) + (8*4) = 20 + 6 + 49 + 32 = 107
  • Y_intermediate_2[2,2] = (5*5) + (6*2) + (7*8) + (8*5) = 25 + 12 + 56 + 40 = 133
  • Y_intermediate_2[2,3] = (5*6) + (6*3) + (7*9) + (8*6) = 30 + 18 + 63 + 48 = 159

So, the result of the matrix multiplication on GPU 2 is:

Y_{intermediate_2} = \begin{pmatrix} 43 & 53 & 63 \newline 107 & 133 & 159 \end{pmatrix}

Apply GeLU Activation Function

Next, the Gaussian Error Linear Unit (GeLU) activation function is applied element-wise to the intermediate results on each GPU. The GeLU function is defined as:

GELU(x) = x \cdot \Phi(x)

where \Phi(x) is the Cumulative Distribution Function (CDF) for the standard normal distribution. A common and precise formula is:

GELU(x) = 0.5 \cdot x \cdot (1 + erf(x / \sqrt{2}))

For simplicity in this manual example, let's assume GeLU(x) ≈ x for large positive x and GeLU(x) ≈ 0 for large negative x. Since all our numbers are large and positive, we'll approximate GeLU(x) = x. In a real scenario, the exact formula would be applied.

(Note: The erf function produces an "S"-shaped sigmoid curve that passes through the origin. As the input x gets very large and positive, erf(x) gets very close to 1. As the input gets very large and negative, erf(x) gets very close to -1.)
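As a quick sanity check of this approximation, the short sketch below evaluates the exact erf-based formula on the values we are about to produce on GPU 1.

```python
import torch

x = torch.tensor([31., 41., 51., 83., 109., 135.])   # values from Y_intermediate_1
gelu_exact = 0.5 * x * (1 + torch.erf(x / 2 ** 0.5))
print(torch.allclose(gelu_exact, x))  # True: Phi(x) is effectively 1 for x this large
```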

So, the outputs on each GPU are:

Y_1 = GeLU(Y_{intermediate_1}) \approx \begin{pmatrix} 31 & 41 & 51 \newline 83 & 109 & 135 \end{pmatrix} (on GPU 1)

Y_2 = GeLU(Y_{intermediate_2}) \approx \begin{pmatrix} 43 & 53 & 63 \newline 107 & 133 & 159 \end{pmatrix} (on GPU 2)

These two matrices, Y1 and Y2, represent the sharded output of the first layer. They remain on their respective GPUs and become the input for the next step.

Step 3: Forward Pass for Z = Dropout(Y * B)

This operation is parallelized using the row-sharded matrix B. The inputs Y1 and Y2 are already on the correct GPUs.

On GPU 1: Calculate Z_partial_1 = Y1 * B1

We multiply the local activation Y1 with the local weight shard B1.

Z_{partial_1} = \begin{pmatrix} 31 & 41 & 51 \newline 83 & 109 & 135 \end{pmatrix} \times \begin{pmatrix} 6 & 5 & 4 & 3 \newline 2 & 1 & 7 & 8 \newline 9 & 8 & 6 & 5 \end{pmatrix}

Calculating each element:

  • Z_partial_1[1,1] = (31*6) + (41*2) + (51*9) = 186 + 82 + 459 = 727
  • Z_partial_1[1,2] = (31*5) + (41*1) + (51*8) = 155 + 41 + 408 = 604
  • Z_partial_1[1,3] = (31*4) + (41*7) + (51*6) = 124 + 287 + 306 = 717
  • Z_partial_1[1,4] = (31*3) + (41*8) + (51*5) = 93 + 328 + 255 = 676
  • Z_partial_1[2,1] = (83*6) + (109*2) + (135*9) = 498 + 218 + 1215 = 1931
  • Z_partial_1[2,2] = (83*5) + (109*1) + (135*8) = 415 + 109 + 1080 = 1604
  • Z_partial_1[2,3] = (83*4) + (109*7) + (135*6) = 332 + 763 + 810 = 1905
  • Z_partial_1[2,4] = (83*3) + (109*8) + (135*5) = 249 + 872 + 675 = 1796

Result on GPU 1:

Z_{partial_1} = \begin{pmatrix} 727 & 604 & 717 & 676 \newline 1931 & 1604 & 1905 & 1796 \end{pmatrix}

On GPU 2: Calculate Z_partial_2 = Y2 * B2

Simultaneously, on GPU 2:

Z_{partial_2} = \begin{pmatrix} 43 & 53 & 63 \newline 107 & 133 & 159 \end{pmatrix} \times \begin{pmatrix} 4 & 3 & 2 & 1 \newline 1 & 2 & 3 & 4 \newline 5 & 6 & 7 & 8 \end{pmatrix}

Calculating each element:

  • Z_partial_2[1,1] = (43*4) + (53*1) + (63*5) = 172 + 53 + 315 = 540
  • Z_partial_2[1,2] = (43*3) + (53*2) + (63*6) = 129 + 106 + 378 = 613
  • Z_partial_2[1,3] = (43*2) + (53*3) + (63*7) = 86 + 159 + 441 = 686
  • Z_partial_2[1,4] = (43*1) + (53*4) + (63*8) = 43 + 212 + 504 = 759
  • Z_partial_2[2,1] = (107*4) + (133*1) + (159*5) = 428 + 133 + 795 = 1356
  • Z_partial_2[2,2] = (107*3) + (133*2) + (159*6) = 321 + 266 + 954 = 1541
  • Z_partial_2[2,3] = (107*2) + (133*3) + (159*7) = 214 + 399 + 1113 = 1726
  • Z_partial_2[2,4] = (107*1) + (133*4) + (159*8) = 107 + 532 + 1272 = 1911

Result on GPU 2:

Z_{partial_2} = \begin{pmatrix} 540 & 613 & 686 & 759 \newline 1356 & 1541 & 1726 & 1911 \end{pmatrix}

All-Reduce Communication Step

The logic behind splitting B by rows is that the final output Z is the sum of the partial results from each GPU. This requires a communication step called an All-Reduce. Each GPU sends its partial result to the others and receives their partial results, summing them all together.

Z_{unactivated} = Z_{partial_1} + Z_{partial_2}
Z_{unactivated} = \begin{pmatrix} 727 & 604 & 717 & 676 \newline 1931 & 1604 & 1905 & 1796 \end{pmatrix} + \begin{pmatrix} 540 & 613 & 686 & 759 \newline 1356 & 1541 & 1726 & 1911 \end{pmatrix}

Adding each element:

Z_{unactivated} = \begin{pmatrix} 727+540 & 604+613 & 717+686 & 676+759 \newline 1931+1356 & 1604+1541 & 1905+1726 & 1796+1911 \end{pmatrix}
Z_{unactivated} = \begin{pmatrix} 1267 & 1217 & 1403 & 1435 \newline 3287 & 3145 & 3631 & 3707 \end{pmatrix}

After the All-Reduce, the complete Z_unactivated matrix is present on both GPUs.
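In a real multi-GPU run this summation would be a call to torch.distributed.all_reduce issued on every rank. Since this walkthrough runs on a single process, the sketch below stands in for it by summing the partial tensors and handing every simulated GPU a copy of the result.

```python
import torch

def simulated_all_reduce(partials):
    """Single-process stand-in for dist.all_reduce(t, op=ReduceOp.SUM):
    sum the partial tensors and give every 'GPU' the complete result."""
    total = torch.stack(partials).sum(dim=0)
    return [total.clone() for _ in partials]

Z_partial_1 = torch.tensor([[727., 604., 717., 676.],
                            [1931., 1604., 1905., 1796.]])
Z_partial_2 = torch.tensor([[540., 613., 686., 759.],
                            [1356., 1541., 1726., 1911.]])

Z_on_gpu1, Z_on_gpu2 = simulated_all_reduce([Z_partial_1, Z_partial_2])
print(Z_on_gpu1)  # [[1267., 1217., 1403., 1435.], [3287., 3145., 3631., 3707.]]
```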

Apply Dropout

Dropout is a regularization technique used during training to prevent overfitting. It randomly sets a fraction of the output elements to zero at each update step.

Let's assume a dropout rate of 50%. We create a "dropout mask" with the same dimensions as Z_unactivated, where each position is randomly 0 or 1.

Example dropout mask:

Mask = \begin{pmatrix} 1 & 0 & 0 & 1 \newline 0 & 1 & 1 & 0 \end{pmatrix}

We multiply our Z_unactivated by this mask:

Z_{masked} = Z_{unactivated} \odot Mask = \begin{pmatrix} 1267 \cdot 1 & 1217 \cdot 0 & 1403 \cdot 0 & 1435 \cdot 1 \newline 3287 \cdot 0 & 3145 \cdot 1 & 3631 \cdot 1 & 3707 \cdot 0 \end{pmatrix}
Z_{masked} = \begin{pmatrix} 1267 & 0 & 0 & 1435 \newline 0 & 3145 & 3631 & 0 \end{pmatrix}

Finally, to ensure the expected sum of outputs remains the same, the remaining non-zero elements are scaled up by a factor of 1 / (1 - dropout_rate). In our case, this is 1 / (1 - 0.5) = 2. This is called "inverted dropout".

Z = Z_{masked} \times 2 = \begin{pmatrix} 2534 & 0 & 0 & 2870 \newline 0 & 6290 & 7262 & 0 \end{pmatrix}

This final matrix Z is the output of our tensor-parallel multi-layer perceptron block.
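The same 1 / (1 - p) scaling is what PyTorch's built-in dropout applies at training time, just with a random mask rather than the fixed mask used here. A small sketch with the fixed mask reproduces the final output:

```python
import torch

Z_unactivated = torch.tensor([[1267., 1217., 1403., 1435.],
                              [3287., 3145., 3631., 3707.]])
mask = torch.tensor([[1., 0., 0., 1.],
                     [0., 1., 1., 0.]])
p = 0.5
Z = Z_unactivated * mask / (1 - p)   # inverted dropout: mask, then scale by 1/(1-p)
print(Z)
# tensor([[2534.,    0.,    0., 2870.],
#         [   0., 6290., 7262.,    0.]])
```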

Summary of the forward pass

  1. Start: We have an input X and two GPUs.
  2. Layer 1 (Column Parallelism):
    • The weight matrix A is split into columns (A1, A2).
    • Each GPU calculates a part of the output: X * A1 on GPU 1 and X * A2 on GPU 2.
    • An activation function (GeLU) is applied independently on each GPU to these partial results (Y1, Y2). No communication is needed.
  3. Layer 2 (Row Parallelism):
    • The weight matrix B is split into rows (B1, B2).
    • Each GPU multiplies its local activation from the previous step with its local weight shard: Y1 * B1 on GPU 1 and Y2 * B2 on GPU 2.
    • An All-Reduce operation sums the partial results from all GPUs. This is the critical communication step.
    • The final operation (Dropout) is applied to the complete, summed tensor, resulting in the final output Z.

Why split Matrix A by columns, Matrix B by rows

The goal is to sequence the operations to ensure the expensive All-Reduce step (where GPUs must synchronize and sum their results) happens as infrequently as possible. To understand this, let's define the two operations and their communication costs.

1. Column Parallelism (What we do to A)

When we split a weight matrix A by columns, the input X is broadcast to all GPUs. Each GPU then calculates a vertical "slice" of the output.

  • Mathematics: Y = X * [A1, A2] = [X * A1, X * A2] = [Y1, Y2]
  • Result: The output Y is naturally sharded (split) across the GPUs. Y1 is on GPU 1 and Y2 is on GPU 2.
  • Communication Cost (Forward Pass): Zero. No communication is needed to produce the sharded output Y. Each GPU works independently. The output is left "as is" on each device, ready for the next step.

2. Row Parallelism (What we do to B)

When we split a weight matrix B by rows, the input must already be sharded in a corresponding way (which is exactly what Column Parallelism just produced!). Each GPU multiplies its local input shard with its local weight shard.

  • Mathematics: Z = Y * B = [Y1, Y2] * [B1; B2] = (Y1 * B1) + (Y2 * B2) (Note: ; denotes stacking rows)
  • Result: Each GPU calculates a partial result. To get the final, correct output Z, these partial results must be summed together.
  • Communication Cost (Forward Pass): One All-Reduce operation. This is the expensive step where Z_partial_1 and Z_partial_2 are added across all GPUs.

Analyzing the Standard Column -> Row Pattern

Now, let's see how these two pieces fit together in our MLP block:

  1. First Layer: Y = GeLU(X * A) (Column Parallelism)

    • We split A by columns.
    • The multiplication X * A produces a sharded activation [Y1, Y2] without any need for communication.
    • Forward Communication Cost: 0
  2. Second Layer: Z = Dropout(Y * B) (Row Parallelism)

    • We split B by rows.
    • The sharded activation [Y1, Y2] from the previous step is the perfect input for this layer.
    • The multiplication Y * B produces partial results Z_partial_1 and Z_partial_2.
    • To get the final output Z, we must perform an All-Reduce to sum the partial results.
    • Forward Communication Cost: 1 All-Reduce

By sequencing the operations this way (Column -> Row), we have successfully computed the output for a full two-layer block with only one communication step.

What Happens If We Switch? (Row -> Column)

Let's imagine we tried to do it the other way around.

  1. First Layer: Y = GeLU(X * A) (Row Parallelism)

    • We split A by rows (A1, A2).
    • For the math X * A to work, we would also have to split the input X by columns (X1, X2).
    • The operation becomes (X1 * A1) + (X2 * A2).
    • This produces partial results which must be summed.
    • Forward Communication Cost: 1 All-Reduce. The output Y is now a complete (un-sharded) tensor, identical on all GPUs.
  2. Second Layer: Z = Dropout(Y * B) (Column Parallelism)

    • We split B by columns (B1, B2).
    • The input Y is a complete tensor from the previous step.
    • The operation Y * [B1, B2] produces a sharded output [Z1, Z2] without communication.
    • Forward Communication Cost: 0

Why column then row?

As you can see, both patterns (Column -> Row and Row -> Column) result in the exact same amount of communication: one All-Reduce per two-layer block.

So why is the Column -> Row pattern the standard convention used in frameworks like Megatron-LM?

The reason is modularity and efficiency in the computational graph. The standard Transformer block consists of an Attention block followed by an MLP block. The Column -> Row pattern creates a perfect compositional unit:

  • The MLP block takes a complete tensor as input (like X).
  • It performs Column Parallel -> Row Parallel.
  • It outputs a complete tensor (Z, after the All-Reduce).

This complete tensor Z can then be fed directly into the next Attention block, which will also start with a column-parallel operation. This creates a clean, predictable, and efficient computational flow where the "sharded" nature of the activations is kept internal to the block, and communication is neatly packaged.

In short, the Column -> Row sequence is a deliberate design pattern that cleverly delays the necessary communication to the very last step within a block, creating a modular and efficient structure for building large neural networks.
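To make this modularity concrete, here is a minimal single-process sketch of such a block: the two "GPUs" are just two weight shards held in one process, and the all-reduce is a plain sum. The block takes a complete X and returns a complete Z, with the sharding kept entirely internal. This is an illustration, not Megatron-LM's actual implementation; dropout is omitted and the real GeLU is used, so the assertion checks against the exact unsharded computation.

```python
import torch
import torch.nn.functional as F

class ToyTensorParallelMLP:
    """Column-parallel first layer, row-parallel second layer, simulated on one device."""

    def __init__(self, A, B, n_shards=2):
        self.A_shards = A.chunk(n_shards, dim=1)  # split A by columns
        self.B_shards = B.chunk(n_shards, dim=0)  # split B by rows

    def forward(self, X):
        # Layer 1: each shard computes GeLU(X @ A_i) locally -- no communication.
        Y_shards = [F.gelu(X @ A_i) for A_i in self.A_shards]
        # Layer 2: each shard computes a partial Y_i @ B_i; the sum below plays
        # the role of the block's single all-reduce.
        Z_partials = [Y_i @ B_i for Y_i, B_i in zip(Y_shards, self.B_shards)]
        return sum(Z_partials)  # complete tensor at the block boundary

torch.manual_seed(0)
X, A, B = torch.randn(2, 4), torch.randn(4, 6), torch.randn(6, 4)
block = ToyTensorParallelMLP(A, B)
reference = F.gelu(X @ A) @ B   # unsharded computation
assert torch.allclose(block.forward(X), reference, atol=1e-5)
```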

Backward pass (backpropagation)

The goal of the backward pass is to calculate the gradients of the loss function with respect to all the learnable parameters (the weights A and B) and the input (X). These gradients tell us how to adjust the weights to reduce the loss.

The process starts with the gradient of the loss with respect to the final output, which we'll call dL/dZ or simply dZ. This initial gradient propagates backward through the network. For this example, we'll assume the process has already begun and the initial gradient dZ is a 2x4 matrix of ones.

Initial Gradient (assumed):

dZ = \begin{pmatrix} 1 & 1 & 1 & 1 \newline 1 & 1 & 1 & 1 \end{pmatrix}

Step 1: Backward Pass through Dropout

The first step is to backpropagate through the Dropout layer. The gradient is only allowed to pass through the neurons that were not "dropped out" during the forward pass. This means we apply the same dropout mask and scaling factor.

Mathematics

The gradient with respect to the pre-dropout output (Z_unactivated) is calculated by element-wise multiplying the incoming gradient dZ with the dropout mask and scaling it by the same factor used in the forward pass.

\frac{\partial L}{\partial Z_{unactivated}} = \frac{\partial L}{\partial Z} \odot Mask \times \frac{1}{1 - \text{dropout\_rate}}

Calculation

From our forward pass, our mask and scaling factor were:
Mask = \begin{pmatrix} 1 & 0 & 0 & 1 \newline 0 & 1 & 1 & 0 \end{pmatrix}, Scale = 2

The gradient dZ is then passed back through this inverted dropout layer:

dZ_{masked} = dZ \odot Mask = \begin{pmatrix} 1 & 1 & 1 & 1 \newline 1 & 1 & 1 & 1 \end{pmatrix} \odot \begin{pmatrix} 1 & 0 & 0 & 1 \newline 0 & 1 & 1 & 0 \end{pmatrix} = \begin{pmatrix} 1 & 0 & 0 & 1 \newline 0 & 1 & 1 & 0 \end{pmatrix}
dZ_{unactivated} = dZ_{masked} \times 2 = \begin{pmatrix} 2 & 0 & 0 & 2 \newline 0 & 2 & 2 & 0 \end{pmatrix}

This gradient, dZ_unactivated, is now available on both GPUs, as the All-Reduce in the forward pass made the output Z_unactivated identical on both.
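You can reproduce this step with autograd: differentiating the fixed-mask inverted dropout from the forward pass against a dZ of all ones gives exactly the gradient above (a verification sketch).

```python
import torch

mask = torch.tensor([[1., 0., 0., 1.],
                     [0., 1., 1., 0.]])
Z_unactivated = torch.tensor([[1267., 1217., 1403., 1435.],
                              [3287., 3145., 3631., 3707.]], requires_grad=True)

Z = Z_unactivated * mask * 2.0    # inverted dropout with p = 0.5
Z.backward(torch.ones_like(Z))    # incoming gradient dZ of all ones
print(Z_unactivated.grad)
# tensor([[2., 0., 0., 2.],
#         [0., 2., 2., 0.]])
```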

Step 2: Backward Pass for Z = Dropout(Y * B)

Now we backpropagate through the second linear layer. This involves calculating the gradients with respect to the weights (dB) and the activations from the previous layer (dY).

Mathematics for Gradients dY and dB

Given the operation Z_unactivated = Y * B, the gradients are:

  1. Gradient w.r.t. Y: \frac{\partial L}{\partial Y} = \frac{\partial L}{\partial Z_{unactivated}} \cdot B^T
  2. Gradient w.r.t. B: \frac{\partial L}{\partial B} = Y^T \cdot \frac{\partial L}{\partial Z_{unactivated}}

In our tensor parallel setup (row parallelism), Y = [Y1, Y2] and B is split into B1 and B2. The forward pass involved local multiplications (Y1*B1, Y2*B2) followed by an All-Reduce. The backward pass reverses this.

  • For dY (Activation Gradients): The gradient dZ_unactivated is multiplied by the transpose of the local weight shard. This is a local computation on each GPU and requires no communication.
    • On GPU 1: dY_1 = dZ_{unactivated} \cdot B_1^T
    • On GPU 2: dY_2 = dZ_{unactivated} \cdot B_2^T
  • For dB (Weight Gradients): The transpose of the local activation Y is multiplied by the incoming gradient.
    • On GPU 1: dB_1 = Y_1^T \cdot dZ_{unactivated}
    • On GPU 2: dB_2 = Y_2^T \cdot dZ_{unactivated}

Calculation for dY

On GPU 1: Calculate dY1 = dZ_unactivated * B1^T

B_1^T = \begin{pmatrix} 6 & 2 & 9 \newline 5 & 1 & 8 \newline 4 & 7 & 6 \newline 3 & 8 & 5 \end{pmatrix}
dY_1 = \begin{pmatrix} 2 & 0 & 0 & 2 \newline 0 & 2 & 2 & 0 \end{pmatrix} \times \begin{pmatrix} 6 & 2 & 9 \newline 5 & 1 & 8 \newline 4 & 7 & 6 \newline 3 & 8 & 5 \end{pmatrix} = \begin{pmatrix} (2*6+2*3) & (2*2+2*8) & (2*9+2*5) \newline (2*5+2*4) & (2*1+2*7) & (2*8+2*6) \end{pmatrix} = \begin{pmatrix} 18 & 20 & 28 \newline 18 & 16 & 28 \end{pmatrix}

On GPU 2: Calculate dY2 = dZ_unactivated * B2^T

B_2^T = \begin{pmatrix} 4 & 1 & 5 \newline 3 & 2 & 6 \newline 2 & 3 & 7 \newline 1 & 4 & 8 \end{pmatrix}
dY_2 = \begin{pmatrix} 2 & 0 & 0 & 2 \newline 0 & 2 & 2 & 0 \end{pmatrix} \times \begin{pmatrix} 4 & 1 & 5 \newline 3 & 2 & 6 \newline 2 & 3 & 7 \newline 1 & 4 & 8 \end{pmatrix} = \begin{pmatrix} (2*4+2*1) & (2*1+2*4) & (2*5+2*8) \newline (2*3+2*2) & (2*2+2*3) & (2*6+2*7) \end{pmatrix} = \begin{pmatrix} 10 & 10 & 26 \newline 10 & 10 & 26 \end{pmatrix}

Calculation for dB

On GPU 1: Calculate dB1 = Y1^T * dZ_unactivated

Y_1^T = \begin{pmatrix} 31 & 83 \newline 41 & 109 \newline 51 & 135 \end{pmatrix}
dB_1 = \begin{pmatrix} 31 & 83 \newline 41 & 109 \newline 51 & 135 \end{pmatrix} \times \begin{pmatrix} 2 & 0 & 0 & 2 \newline 0 & 2 & 2 & 0 \end{pmatrix} = \begin{pmatrix} (31*2) & (83*2) & (83*2) & (31*2) \newline (41*2) & (109*2) & (109*2) & (41*2) \newline (51*2) & (135*2) & (135*2) & (51*2) \end{pmatrix} = \begin{pmatrix} 62 & 166 & 166 & 62 \newline 82 & 218 & 218 & 82 \newline 102 & 270 & 270 & 102 \end{pmatrix}

On GPU 2: Calculate dB2 = Y2^T * dZ_unactivated

Y_2^T = \begin{pmatrix} 43 & 107 \newline 53 & 133 \newline 63 & 159 \end{pmatrix}
dB_2 = \begin{pmatrix} 43 & 107 \newline 53 & 133 \newline 63 & 159 \end{pmatrix} \times \begin{pmatrix} 2 & 0 & 0 & 2 \newline 0 & 2 & 2 & 0 \end{pmatrix} = \begin{pmatrix} (43*2) & (107*2) & (107*2) & (43*2) \newline (53*2) & (133*2) & (133*2) & (53*2) \newline (63*2) & (159*2) & (159*2) & (63*2) \end{pmatrix} = \begin{pmatrix} 86 & 214 & 214 & 86 \newline 106 & 266 & 266 & 106 \newline 126 & 318 & 318 & 126 \end{pmatrix}
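The same numbers fall out of two local matrix multiplications per GPU; the sketch below checks the GPU 1 shard (values copied from the walkthrough above).

```python
import torch

dZ_unactivated = torch.tensor([[2., 0., 0., 2.],
                               [0., 2., 2., 0.]])
Y1 = torch.tensor([[31., 41., 51.],
                   [83., 109., 135.]])
B1 = torch.tensor([[6., 5., 4., 3.],
                   [2., 1., 7., 8.],
                   [9., 8., 6., 5.]])

dY1 = dZ_unactivated @ B1.T   # activation gradient, stays local to GPU 1
dB1 = Y1.T @ dZ_unactivated   # weight gradient, also local
print(dY1)  # [[18., 20., 28.], [18., 16., 28.]]
print(dB1)  # [[62., 166., 166., 62.], [82., 218., 218., 82.], [102., 270., 270., 102.]]
```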

Step 3: Backward Pass through GeLU Activation

Next, the gradients dY1 and dY2 are passed backward through the GeLU activation function.

Mathematics

The chain rule dictates that we multiply the incoming gradient by the derivative of the activation function evaluated at its original input.

\frac{\partial L}{\partial Y_{intermediate}} = \frac{\partial L}{\partial Y} \odot GeLU'(Y_{intermediate})

The derivative of GeLU is GeLU'(x) = 0.5 \cdot (1 + erf(x/\sqrt{2})) + \frac{x}{\sqrt{2\pi}}e^{-x^2/2}. As in the forward pass, for the large positive numbers in our example, this derivative is very close to 1. We will use this approximation for simplicity.

GeLU'(x) \approx 1

Calculation

We simply pass the gradients through, as multiplying by 1 does not change them.
On GPU 1:

dY_{intermediate_1} = dY_1 \odot 1 = \begin{pmatrix} 18 & 20 & 28 \newline 18 & 16 & 28 \end{pmatrix}

On GPU 2:

dY_{intermediate_2} = dY_2 \odot 1 = \begin{pmatrix} 10 & 10 & 26 \newline 10 & 10 & 26 \end{pmatrix}

Step 4: Backward Pass for Y = GeLU(X * A)

Finally, we backpropagate through the first linear layer. This computes the gradients dA (for the weights) and dX (for the input).

Mathematics for Gradients dX and dA

Given Y_intermediate = X * A, the gradients are:

  1. Gradient w.r.t. X: \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y_{intermediate}} \cdot A^T
  2. Gradient w.r.t. A: \frac{\partial L}{\partial A} = X^T \cdot \frac{\partial L}{\partial Y_{intermediate}}

In our tensor parallel setup (column parallelism), the input X was broadcast, and the output Y_intermediate was sharded (Y_intermediate_1, Y_intermediate_2).

  • For dA (Weight Gradients): The calculation is local.
    • On GPU 1: dA_1 = X^T \cdot dY_{intermediate_1}
    • On GPU 2: dA_2 = X^T \cdot dY_{intermediate_2}
  • For dX (Input Gradients): Each GPU calculates a partial gradient for dX. These partial gradients must be summed together using an All-Reduce operation.
    • On GPU 1: dX_{partial_1} = dY_{intermediate_1} \cdot A_1^T
    • On GPU 2: dX_{partial_2} = dY_{intermediate_2} \cdot A_2^T
    • Final Gradient: dX = dX_{partial_1} + dX_{partial_2}

Calculation for dA

On GPU 1: Calculate dA1 = X^T * dY_intermediate_1

X^T = \begin{pmatrix} 1 & 5 \newline 2 & 6 \newline 3 & 7 \newline 4 & 8 \end{pmatrix}
dA_1 = \begin{pmatrix} 1 & 5 \newline 2 & 6 \newline 3 & 7 \newline 4 & 8 \end{pmatrix} \times \begin{pmatrix} 18 & 20 & 28 \newline 18 & 16 & 28 \end{pmatrix} = \begin{pmatrix} (18+90) & (20+80) & (28+140) \newline (36+108) & (40+96) & (56+168) \newline (54+126) & (60+112) & (84+196) \newline (72+144) & (80+128) & (112+224) \end{pmatrix} = \begin{pmatrix} 108 & 100 & 168 \newline 144 & 136 & 224 \newline 180 & 172 & 280 \newline 216 & 208 & 336 \end{pmatrix}

On GPU 2: Calculate dA2 = X^T * dY_intermediate_2

dA_2 = \begin{pmatrix} 1 & 5 \newline 2 & 6 \newline 3 & 7 \newline 4 & 8 \end{pmatrix} \times \begin{pmatrix} 10 & 10 & 26 \newline 10 & 10 & 26 \end{pmatrix} = \begin{pmatrix} (10+50) & (10+50) & (26+130) \newline (20+60) & (20+60) & (52+156) \newline (30+70) & (30+70) & (78+182) \newline (40+80) & (40+80) & (104+208) \end{pmatrix} = \begin{pmatrix} 60 & 60 & 156 \newline 80 & 80 & 208 \newline 100 & 100 & 260 \newline 120 & 120 & 312 \end{pmatrix}

Calculation for dX (with All-Reduce)

On GPU 1: Calculate dX_partial_1 = dY_intermediate_1 * A1^T

A_1^T = \begin{pmatrix} 1 & 7 & 4 & 1 \newline 2 & 8 & 5 & 2 \newline 3 & 9 & 6 & 3 \end{pmatrix}
dX_{partial_1} = \begin{pmatrix} 18 & 20 & 28 \newline 18 & 16 & 28 \end{pmatrix} \times \begin{pmatrix} 1 & 7 & 4 & 1 \newline 2 & 8 & 5 & 2 \newline 3 & 9 & 6 & 3 \end{pmatrix} = \begin{pmatrix} (18+40+84) & (126+160+252) & (72+100+168) & (18+40+84) \newline (18+32+84) & (126+128+252) & (72+80+168) & (18+32+84) \end{pmatrix} = \begin{pmatrix} 142 & 538 & 340 & 142 \newline 134 & 506 & 320 & 134 \end{pmatrix}

On GPU 2: Calculate dX_partial_2 = dY_intermediate_2 * A2^T

A_2^T = \begin{pmatrix} 4 & 1 & 7 & 4 \newline 5 & 2 & 8 & 5 \newline 6 & 3 & 9 & 6 \end{pmatrix}
dX_{partial_2} = \begin{pmatrix} 10 & 10 & 26 \newline 10 & 10 & 26 \end{pmatrix} \times \begin{pmatrix} 4 & 1 & 7 & 4 \newline 5 & 2 & 8 & 5 \newline 6 & 3 & 9 & 6 \end{pmatrix} = \begin{pmatrix} (40+50+156) & (10+20+78) & (70+80+234) & (40+50+156) \newline (40+50+156) & (10+20+78) & (70+80+234) & (40+50+156) \end{pmatrix} = \begin{pmatrix} 246 & 108 & 384 & 246 \newline 246 & 108 & 384 & 246 \end{pmatrix}

All-Reduce Communication Step
Finally, the partial gradients for X are summed across GPUs.

dX = dX_{partial_1} + dX_{partial_2} = \begin{pmatrix} 142 & 538 & 340 & 142 \newline 134 & 506 & 320 & 134 \end{pmatrix} + \begin{pmatrix} 246 & 108 & 384 & 246 \newline 246 & 108 & 384 & 246 \end{pmatrix} = \begin{pmatrix} 388 & 646 & 724 & 388 \newline 380 & 614 & 704 & 380 \end{pmatrix}

This final dX is the gradient with respect to the input of the MLP block, ready to be passed further backward to the previous layer in the network. The weight gradients dA1, dA2, dB1, and dB2 are used by the optimizer to update the model's weights on each respective GPU.

The Duality of Communication in Forward vs. Backward Pass

A fundamental principle of backpropagation in tensor parallelism is that the communication patterns are reversed.

  • If a layer requires an All-Reduce in the forward pass, its corresponding backward pass for the input gradient requires no communication (Identity).
  • If a layer requires no communication (Identity) in the forward pass, its corresponding backward pass for the input gradient requires an All-Reduce.

Here is a simple table summarizing the communication for the input gradients (dX and dY):

| Layer Type (Forward) | Forward Pass Communication (f) | Backward Pass Communication (g) |
| --- | --- | --- |
| Column Parallelism | Identity (No communication) | All-Reduce (Sum partial results) |
| Row Parallelism | All-Reduce (Sum partial results) | Identity (No communication) |

(Note: The gradients for the weights, dA and dB, are always local calculations and never require an All-Reduce, so we can focus on the activation gradients.)
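In Megatron-LM, this duality is implemented as a pair of tiny autograd functions, usually called f and g. The sketch below shows their structure on a single process, with comments marking where torch.distributed.all_reduce would be called in a real multi-GPU run (here those lines are no-ops, since there is nothing to sum across). The class names are illustrative, not Megatron-LM's actual identifiers.

```python
import torch

class CopyToTensorParallelRegion(torch.autograd.Function):
    """'f': identity in the forward pass, all-reduce in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x  # forward: identity (X is simply used by every shard)

    @staticmethod
    def backward(ctx, grad_output):
        # In a real run: torch.distributed.all_reduce(grad_output) to sum the
        # partial dX contributions from every tensor-parallel rank.
        return grad_output


class ReduceFromTensorParallelRegion(torch.autograd.Function):
    """'g': all-reduce in the forward pass, identity in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # In a real run: torch.distributed.all_reduce(x) to sum the partial
        # Z contributions from every tensor-parallel rank.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # backward: identity (every rank already holds the full dZ)
```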

Backpropagation Through the Standard Column -> Row Pattern

Now, let's trace the backward pass through our standard two-layer block, keeping the table above in mind. The process starts with dZ, the gradient from the layer above. Because the forward pass ended with an All-Reduce, Z was a complete tensor on all GPUs, which means dZ is also a complete tensor.

  1. Backprop through Layer 2: Z = Dropout(Y * B) (Row Parallelism)

    • Forward Pass required: All-Reduce.
    • Backward Pass requires: Identity (No communication).
    • Calculation (dY = dZ * B^T): The complete gradient dZ is multiplied by the local weight shard B^T on each GPU. This directly produces the correct sharded activation gradient dY ([dY1, dY2]) without any need for communication.
  2. Backprop through Layer 1: Y = GeLU(X * A) (Column Parallelism)

    • Forward Pass required: Identity.
    • Backward Pass requires: All-Reduce.
    • Calculation (dX = dY * A^T): The sharded gradient dY from the previous step is multiplied by the local weight shard A^T. This produces a partial gradient for the input, dX_partial. To get the final, complete gradient dX, these partial results must be summed across all GPUs. This is the All-Reduce step.

As you can see, the backward pass for our Column -> Row block has exactly one All-Reduce operation, just like the forward pass.

What if we used Row -> Column?

The result would be perfectly symmetric:

  1. Backprop through Layer 2 (Column Parallelism): The incoming gradient dZ would be sharded. This step would require an All-Reduce to produce a complete gradient dY.
  2. Backprop through Layer 1 (Row Parallelism): The complete gradient dY would be the input. This step would require no communication (Identity) to produce the sharded input gradient dX.

The total communication remains one All-Reduce in the forward pass and one in the backward pass.

The Preference is for Modularity, Not Cost

The total cost is the same. So why do we prefer Column -> Row?

The reason is that this pattern creates a perfectly self-contained, modular block.

  • Input: The block takes a complete tensor (X).
  • Internal State: It creates a sharded tensor (Y) internally.
  • Output: It performs an All-Reduce at the end and outputs a complete tensor (Z).

The backward pass mirrors this perfectly:

  • Input: It takes a complete gradient (dZ).
  • Internal State: It creates a sharded gradient (dY) internally.
  • Output: It performs an All-Reduce at the end and outputs a complete gradient (dX).

This means you can stack these Transformer blocks one after another without worrying about whether the tensor you're passing is complete or sharded. Each block handles its own internal sharding and communication, always presenting a clean, complete tensor at its boundaries. This makes the overall architecture of a large model much simpler to design, implement, and debug.

Communication cost of tensor parallelism

In our walkthrough, communication between GPUs did not happen after every mathematical operation. It only occurred at very specific points where partial results needed to be combined to form a complete tensor. Let's pinpoint them:

  1. Forward Pass Communication: This happened at the end of Step 3. To calculate the final Z_unactivated, we needed to sum the partial results from each GPU.
    Z_{unactivated} = Z_{partial_1} + Z_{partial_2}
    This summation is an all-reduce operation. GPU 1 sends its Z_partial_1 and receives Z_partial_2; GPU 2 sends Z_partial_2 and receives Z_partial_1. Both then compute the sum.

  2. Backward Pass Communication: This happened at the end of Step 4. To calculate the final input gradient dX, we needed to sum the partial gradients from each GPU.
    dX = dX_{partial_1} + dX_{partial_2}
    This summation is also an all-reduce operation.

For a single pass through this two-layer block, we perform two full all-reduce operations.

Analyzing the Communication Volume

The formula for the per-layer communication volume of tensor parallelism is:

8bsh \left( \frac{n_{\text{devices}} - 1}{n_{\text{devices}}} \right)

Source: Stanford CS 336 Language Modeling from Scratch, Lecture 7: Parallelism 1, 1 hour, 1 min 8 sec

  • b: This is the batch size. In our example, the input matrix X had 2 rows, so b = 2.
  • s: This is the sequence length. In a simple MLP, we can consider this to be 1, so s = 1.
  • h: This is the hidden size of the tensor being communicated. The tensor we communicated in the forward pass was Z_partial, which had a shape of (2, 4). The size of the last dimension is 4, so h = 4.
  • n_devices: The number of GPUs, which is n_{\text{devices}} = 2.

In short, bsh represents the total number of elements in the activation tensor being communicated.

Volume of data in one all-reduce:

In the forward pass all-reduce, the tensor being communicated is Z_partial. The number of elements in this tensor is:

Volume_{elements} = b \times s \times h = 2 \times 1 \times 4 = 8 \text{ elements}

The all-reduce operation is a collective communication where each GPU sends its 8 elements and receives 8 elements from its peer.

The term \left( \frac{n_{\text{devices}} - 1}{n_{\text{devices}}} \right) in the formula characterizes the volume of a standard ring all-reduce algorithm. For our 2 GPUs, this factor is:

\left( \frac{2 - 1}{2} \right) = \frac{1}{2}

This signifies that, in a ring all-reduce, each GPU sends and receives a fraction (n_devices - 1)/n_devices of the tensor's elements rather than the entire tensor.

The factor of 8 in the formula often accounts for two things: 4 bytes for a standard 32-bit floating-point number and a factor of 2 because communication happens in both the forward and backward passes.
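Plugging the toy numbers into the formula (a small sketch that follows the interpretation above, treating the 8 as 2 passes times 4 bytes per fp32 element):

```python
def tp_comm_volume_bytes(b, s, h, n_devices, bytes_per_element=4, n_passes=2):
    """Per-layer tensor-parallel communication volume, following the formula above."""
    elements = b * s * h * (n_devices - 1) / n_devices
    return n_passes * bytes_per_element * elements

print(tp_comm_volume_bytes(b=2, s=1, h=4, n_devices=2))  # 32.0 bytes for this toy block
```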

Why the communication cost of tensor parallelism is high

The communication cost of tensor parallelism is high due to the combination of volume and frequency:

  1. High-Cost Operation: The all-reduce is a "collective" operation that is much more complex and latency-sensitive than a simple point-to-point send/receive. It requires synchronization across all GPUs in the group.

  2. Large Data Volume: The amount of data communicated in each all-reduce is the full activation tensor (bsh), which can be very large in modern LLMs.

  3. High Frequency: This expensive, high-volume operation is not performed once, but twice for every single transformer block in the model during a full training step (forward + backward pass).

Therefore, while tensor parallelism is excellent at keeping all GPUs utilized (eliminating the "bubble" of pipeline parallelism), it comes at the cost of constant, high-volume communication across all participating devices. This is why the general recommendation is to use tensor parallelism only where we have low-latency, high-bandwidth interconnects like NVIDIA's NVLink, which are specifically designed to handle this intense communication pattern efficiently.


Code

For those who prefer to follow along in code, the walkthrough of both the forward and backward passes is replicated below in Python.

Setup

import torch
import torch.nn.functional as F

# --------------------------------------------------------------------------
# 1. SETUP: Define tensors from the manual example
# --------------------------------------------------------------------------
# Set requires_grad=True to track gradients
X = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]], dtype=torch.float32, requires_grad=True)

A = torch.tensor([[1, 2, 3, 4, 5, 6],
                  [7, 8, 9, 1, 2, 3],
                  [4, 5, 6, 7, 8, 9],
                  [1, 2, 3, 4, 5, 6]], dtype=torch.float32, requires_grad=True)

B = torch.tensor([[6, 5, 4, 3],
                  [2, 1, 7, 8],
                  [9, 8, 6, 5],
                  [4, 3, 2, 1],
                  [1, 2, 3, 4],
                  [5, 6, 7, 8]], dtype=torch.float32, requires_grad=True)

# NOTE: For verification, we use an identity function to match the manual
# calculation's approximation of GeLU(x) ≈ x. In a real model, you would
# use F.gelu(x).
gelu = lambda x: x

# Define the fixed dropout mask from the manual example
dropout_mask = torch.tensor([[1, 0, 0, 1],
                             [0, 1, 1, 0]], dtype=torch.float32)
dropout_rate = 0.5
dropout_scale = 1 / (1 - dropout_rate)

print("--- Initial Tensors ---")
print(f"X:\n{X}\n")
print(f"A:\n{A}\n")
print(f"B:\n{B}\n")

Output

--- Initial Tensors ---
X:
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]], requires_grad=True)

A:
tensor([[1., 2., 3., 4., 5., 6.],
        [7., 8., 9., 1., 2., 3.],
        [4., 5., 6., 7., 8., 9.],
        [1., 2., 3., 4., 5., 6.]], requires_grad=True)

B:
tensor([[6., 5., 4., 3.],
        [2., 1., 7., 8.],
        [9., 8., 6., 5.],
        [4., 3., 2., 1.],
        [1., 2., 3., 4.],
        [5., 6., 7., 8.]], requires_grad=True)

Forward pass

# --------------------------------------------------------------------------
# 2. FORWARD PASS
# --------------------------------------------------------------------------
print("--- FORWARD PASS ---")

# --- Simulate Tensor Parallelism on 2 GPUs ---

# Split A by columns (Column Parallelism)
# "GPU 1" gets A1, "GPU 2" gets A2
A1 = A[:, :3]
A2 = A[:, 3:]

# Split B by rows (Row Parallelism)
# "GPU 1" gets B1, "GPU 2" gets B2
B1 = B[:3, :]
B2 = B[3:, :]

# --- Layer 1: Y = GeLU(X * A) ---
# Operation is local to each GPU, no communication needed.

# "GPU 1" computes its part
Y_intermediate_1 = X @ A1
Y1 = gelu(Y_intermediate_1)

# "GPU 2" computes its part
Y_intermediate_2 = X @ A2
Y2 = gelu(Y_intermediate_2)

print(f"Y1 (on GPU 1):\n{Y1}\n")
print(f"Y2 (on GPU 2):\n{Y2}\n")

# --- Layer 2: Z = Dropout(Y * B) ---

# Local matrix multiplication on each GPU
# Note: Y is implicitly [Y1, Y2]
Z_partial_1 = Y1 @ B1
Z_partial_2 = Y2 @ B2

print(f"Z_partial_1 (on GPU 1):\n{Z_partial_1}\n")
print(f"Z_partial_2 (on GPU 2):\n{Z_partial_2}\n")

# *** All-Reduce Communication Step ***
# The partial results are summed across GPUs.
Z_unactivated = Z_partial_1 + Z_partial_2
print(f"Z_unactivated (after All-Reduce):\n{Z_unactivated}\n")

# Apply Dropout
Z_masked = Z_unactivated * dropout_mask
Z = Z_masked * dropout_scale

print(f"Final Output Z:\n{Z}\n")

Output

--- FORWARD PASS ---
Y1 (on GPU 1):
tensor([[ 31.,  41.,  51.],
        [ 83., 109., 135.]], grad_fn=<MmBackward0>)

Y2 (on GPU 2):
tensor([[ 43.,  53.,  63.],
        [107., 133., 159.]], grad_fn=<MmBackward0>)

Z_partial_1 (on GPU 1):
tensor([[ 727.,  604.,  717.,  676.],
        [1931., 1604., 1905., 1796.]], grad_fn=<MmBackward0>)

Z_partial_2 (on GPU 2):
tensor([[ 540.,  613.,  686.,  759.],
        [1356., 1541., 1726., 1911.]], grad_fn=<MmBackward0>)

Z_unactivated (after All-Reduce):
tensor([[1267., 1217., 1403., 1435.],
        [3287., 3145., 3631., 3707.]], grad_fn=<AddBackward0>)

Final Output Z:
tensor([[2534.,    0.,    0., 2870.],
        [   0., 6290., 7262.,    0.]], grad_fn=<MulBackward0>)

Backward pass

# --------------------------------------------------------------------------
# 3. BACKWARD PASS
# --------------------------------------------------------------------------
print("--- BACKWARD PASS ---")

# Define a dummy loss. Z.sum() creates an initial gradient dZ of all ones.
loss = Z.sum()
print(f"Loss (for backprop):\n{loss}\n")

# Automatically compute gradients for all tensors with requires_grad=True
loss.backward()

Output

--- BACKWARD PASS ---
Loss (for backprop):
18956.0

Verification

# --------------------------------------------------------------------------
# 4. VERIFICATION: Check gradients against manual calculations
# --------------------------------------------------------------------------
print("--- VERIFYING GRADIENTS ---\n")

# Gradient for X (dX)
# This required an All-Reduce in the backward pass
print(f"Gradient for X (dX):\n{X.grad}\n")

# Gradients for A (dA1 and dA2)
# A.grad contains the gradients for the entire A matrix. We can verify
# by slicing it into the parts corresponding to A1 and A2.
print(f"Gradient for A1 (dA1):\n{A.grad[:, :3]}\n")
print(f"Gradient for A2 (dA2):\n{A.grad[:, 3:]}\n")

# Gradients for B (dB1 and dB2)
# Similarly, B.grad contains the full gradient.
print(f"Gradient for B1 (dB1):\n{B.grad[:3, :]}\n")
print(f"Gradient for B2 (dB2):\n{B.grad[3:, :]}\n")

Output

--- VERIFYING GRADIENTS ---

Gradient for X (dX):
tensor([[388., 646., 724., 388.],
        [380., 614., 704., 380.]])

Gradient for A1 (dA1):
tensor([[108., 100., 168.],
        [144., 136., 224.],
        [180., 172., 280.],
        [216., 208., 336.]])

Gradient for A2 (dA2):
tensor([[ 60.,  60., 156.],
        [ 80.,  80., 208.],
        [100., 100., 260.],
        [120., 120., 312.]])

Gradient for B1 (dB1):
tensor([[ 62., 166., 166.,  62.],
        [ 82., 218., 218.,  82.],
        [102., 270., 270., 102.]])

Gradient for B2 (dB2):
tensor([[ 86., 214., 214.,  86.],
        [106., 266., 266., 106.],
        [126., 318., 318., 126.]])
