Table of Contents
- Motivation
- What are collective operations
- Communication load and time-space tradeoff
- How does data sharding affect time-space-communication tradeoff
- Rank vs node
- Common collective operations
- Broadcast
- Scatter
- Gather
- Reduce
- All-reduce
- Reduce-scatter
- All-gather
- How is all-gather different from gather?
- Why Reduce-Scatter + All-gather = All-reduce
- Appendix A: Non-blocking all-gather
Motivation
Collective operations are fundamental to distributed computing, but I found it hard to fully understand them from reference documentation such as NCCL's.
This article instead works through a simple example for each of the more common operations, which I believe is the quickest way to understand them.
As the intent of this article is to understand the operations at a high level, I will not be going into optimization patterns such as tree-based or ring-based algorithms to maximize parallelism and minimize network contention. If you would like to go into further details, you may refer to resources such as the MPI tutorial.
This article was written with the assistance of Google Gemini 2.5 Pro.
What are collective operations
In parallel and distributed computing, collective operations are specialized functions that involve communication among a group of processes working together on a task. These processes are identified by a unique integer ID called a rank. Instead of a single process sending a message to another single process (which is called point-to-point communication), collective operations are coordinated actions that all ranks in a defined group must participate in.
This group of processes is formally known as a communicator. While programs often start with a default communicator that includes all processes (like MPI_COMM_WORLD in MPI), it's possible to create smaller subgroups or communicators. This allows for collective operations to be performed on just a subset of the total available processes, which is a powerful feature for complex parallel algorithms. (Source: MPI_COMM_WORLD)
Collective operations are fundamental building blocks in parallel programming libraries like the Message Passing Interface (MPI) and NVIDIA's NCCL for GPU-based computing. These libraries provide optimized implementations of collective operations, so developers do not have to implement them from scratch.
Communication load and time-space tradeoff
In traditional algorithm analysis, time complexity quantifies the number of operations an algorithm performs, while space complexity measures the amount of memory it requires. However, in distributed systems, where data and computation are spread across multiple nodes, the cost of communication could become a dominant factor in overall performance.
Communication cost can be broken down into:
- Latency (α): The time it takes to initiate a message transfer, regardless of its size. This is often influenced by factors like network protocol overhead and routing.
- Bandwidth Cost (β): The cost per unit of data (e.g., per byte) to transmit the message over the network.
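In this α-β model, the time to transfer a single message of n bytes is commonly written as the first line below. The values of α and β used in the other two lines are hypothetical, chosen only to show that small messages are dominated by latency and large ones by bandwidth.

```latex
\begin{aligned}
T(n) &= \alpha + n\beta \\
T(10^{3}\ \text{bytes}) &= 5\ \mu\text{s} + 10^{3} \times 0.1\ \text{ns} = 5.1\ \mu\text{s} && (\alpha = 5\ \mu\text{s},\ \beta = 0.1\ \text{ns/byte}) \\
T(10^{6}\ \text{bytes}) &= 5\ \mu\text{s} + 10^{6} \times 0.1\ \text{ns} = 105\ \mu\text{s}
\end{aligned}
```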
Other considerations to manage communication cost include:
- Network Topology: The physical and logical arrangement of nodes (e.g., ring, mesh, tree, fat-tree) has a profound impact on the performance of collective operations. An algorithm that is optimal on a fully connected network may perform poorly on a network with high contention for certain links.
- Contention: When multiple messages try to traverse the same network link simultaneously, contention occurs, leading to delays. This is a critical factor in real-world performance that is not captured by the simple α-β model.
- Synchronization Overhead: In their standard (blocking) form, collective operations do not return until the calling process's part of the operation is complete, which usually means waiting for the other processes in the group to reach the same call. This synchronization can introduce significant idle time, especially if there are "straggler" nodes that are slower than others. (Note: Many libraries also provide non-blocking variants, which allow a process to initiate the collective operation, immediately proceed with other tasks, and check for the operation's completion later. This technique is key to hiding network latency by overlapping communication with computation. See Appendix A: Non-blocking all-gather for an example.)
- In-network Computation: Modern networking hardware can sometimes perform reduction operations directly within the network switches, which can dramatically reduce the amount of data that needs to be transmitted and lower the overall latency.
How does data sharding affect time-space-communication tradeoff
- Time Complexity: By sharding the data, each node is responsible for processing a smaller portion of the overall dataset. This can significantly reduce the local computation time on each node. For example, if a computation has a time complexity of O(N) on a single machine, distributing the data across 'P' processors could ideally reduce the computation time on each processor to O(N/P).
- Space Complexity: Similarly, sharding reduces the amount of memory required on each individual node, as each only needs to store its assigned shard of the data.
- Communication Load: However, increasing the number of shards (and thus the number of participating nodes) often increases the communication overhead. Collective operations require coordination and data exchange between nodes. For instance, in an all-reduce operation, as the number of nodes increases, more messages may need to be sent across the network, and the complexity of coordinating these messages can grow.
Rank vs node
Think of a node as a physical, standalone computer. In the context of high-performance computing (HPC), a cluster is made up of many of these nodes connected by a network. Each node typically has its own:
- Processors (CPUs) with multiple cores
- Memory (RAM)
- Operating System
- Potentially one or more GPUs (as is common with NVIDIA systems)
A rank is a unique integer ID assigned to a specific process in a parallel computing job. When you launch a parallel application (for example, using a framework like MPI or a library like NCCL), you specify how many parallel processes you want to run. Each of these processes is assigned a rank, starting from 0. Within a communication group (e.g. MPI's MPI_COMM_WORLD), each process has a unique rank. This allows the user to send and receive messages to and from specific processes.
Multiple ranks can run on a single node, or they can be distributed across multiple nodes. For example, ranks 0 and 1 could be on Node A, and ranks 2 and 3 on Node B.
What about the difference in communication cost between ranks and nodes?
While sharding inherently increases the need for communication, the actual overhead experienced by individual ranks is heavily influenced by their physical location within the computing cluster.
- Intra-Node Communication: Multiple ranks can execute on a single node. Communication between these ranks is generally faster and more efficient as it occurs within the same machine, often over high-speed interconnects like NVLink for GPUs.
- Inter-Node Communication: When ranks on different nodes need to communicate, the data must traverse the network, which introduces significantly higher latency and lower bandwidth compared to intra-node communication.
Therefore, how data sharding impacts communication load is not uniform across all ranks. The placement of ranks across nodes becomes a critical factor. A well-designed sharding strategy will aim to minimize the frequency and volume of inter-node communication by keeping data that needs to be processed together on the same node, thus allowing the corresponding ranks to communicate more efficiently.
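To make the placement point concrete, here is a small Python sketch (not tied to any MPI or NCCL API) that counts how many hops of a ring of ranks cross a node boundary under two hypothetical placements of 4 ranks on 2 nodes. The ring pattern and both placement functions are illustrative assumptions.

```python
def block_placement(rank: int, ranks_per_node: int = 2) -> int:
    """Hypothetical layout: consecutive ranks share a node (ranks 0,1 -> node 0)."""
    return rank // ranks_per_node


def round_robin_placement(rank: int, num_nodes: int = 2) -> int:
    """Hypothetical layout: ranks are dealt out to nodes in turn (ranks 0,2 -> node 0)."""
    return rank % num_nodes


def ring_inter_node_hops(num_ranks: int, node_of) -> int:
    """Count ring hops r -> (r + 1) % num_ranks whose endpoints are on different nodes."""
    return sum(
        node_of(r) != node_of((r + 1) % num_ranks) for r in range(num_ranks)
    )


# Four ranks on two nodes, as in the Node A / Node B example.
print("block placement:      ", ring_inter_node_hops(4, block_placement))        # 2
print("round-robin placement:", ring_inter_node_hops(4, round_robin_placement))  # 4
```

With block placement only two of the four hops leave a node; with round-robin placement every hop does, so a ring-based collective would push all of its traffic over the slower inter-node network.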
Common collective operations
Broadcast
A broadcast operation sends the same piece of data from one designated "root" rank to all other ranks in the group.
Source: Nvidia NCCL documentation
The formula `out[i] = in[i]` says that after the operation, every rank's output buffer holds a copy of the root's input buffer. Imagine Rank 2 has a single, important value, say the number 42, that all other ranks need for a calculation.
- Initial State:
  - Rank 0: has no data
  - Rank 1: has no data
  - Rank 2: has data 42
  - Rank 3: has no data
- Step-by-Step Process:
  - Initiation: Rank 2 is designated as the root and initiates the broadcast.
  - Transmission: Rank 2 sends the value 42 to all other ranks (Rank 0, Rank 1, and Rank 3). In practice, this might happen through a direct send to each or through a more optimized network pattern like a tree, where Rank 2 sends to Rank 0 and Rank 1, and they in turn forward it to others.
  - Reception: Each of the other ranks receives the value 42.
- Final State:
  - Rank 0: has data 42
  - Rank 1: has data 42
  - Rank 2: has data 42
  - Rank 3: has data 42
Analogy: A manager (Rank 2) sending the same memo (42) to every employee (the other ranks) in a department.
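The following Python sketch simulates a broadcast inside a single process: each rank's buffer is one entry in a list, and "sending" is copying. It uses a binomial tree, one common tree pattern (not necessarily the one NCCL or a given MPI library uses, and different from the Rank 0 / Rank 1 tree mentioned in the text), in which the root's data reaches all P ranks in about log2(P) rounds. The function name and data layout are illustrative assumptions, not a library API.

```python
from typing import List, Optional, Tuple


def tree_broadcast(buffers: List[Optional[list]], root: int) -> Tuple[List[list], int]:
    """Simulate out[i] = in[i] on every rank using a binomial tree.

    Returns the per-rank buffers and the number of communication rounds used.
    """
    p = len(buffers)
    bufs = list(buffers)
    rounds = 0
    mask = 1
    while mask < p:
        # Ranks whose position relative to the root is below `mask` already
        # hold the data; each forwards it to the rank `mask` positions away.
        for v in range(mask):
            if v + mask < p:
                src = (root + v) % p
                dst = (root + v + mask) % p
                bufs[dst] = list(bufs[src])
        mask *= 2
        rounds += 1
    return bufs, rounds


# Initial state from the example: only Rank 2 holds the value 42.
result, rounds = tree_broadcast([None, None, [42], None], root=2)
print(result, rounds)  # [[42], [42], [42], [42]] 2
```

A direct send from the root would need P-1 sequential sends from one rank; the tree spreads the sending work so that the number of rounds grows only logarithmically with P.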
Scatter
A scatter operation takes an array of data on a single root rank and distributes different chunks of that array to the ranks in the group (the root keeps one chunk for itself). The data distribution for a scatter operation can be described by the following formula:
`outY[i] = in[Y*count+i]`
This formula shows how each destination rank `Y` fills its output buffer `outY` with data from the root's single input buffer `in`.
Explanation of the Formula
This formula describes a direct data copy from a specific segment of the root's large buffer to the smaller buffer of each destination rank.
- `in[...]`: This is the source of the data. It refers to the single, large input buffer that exists only on the root rank.
- `outY[i]`: This is the destination for the data. It represents the i-th element of the output buffer on a specific destination rank `Y`. Each rank, including the root, will have its own `outY` buffer after the operation.
- `[Y*count+i]`: This is the source index. It's the "selection logic" that determines which chunk of the root's `in` buffer is sent to which rank.
  - `Y`: The rank of the destination process that will receive the data chunk.
  - `count`: The number of elements in each distributed chunk (i.e., the size of each `outY` buffer).
  - `i`: The local index within that chunk.
In essence, the root rank uses the formula to slice its large `in` buffer. The chunk corresponding to `Y=0` goes to rank 0, the chunk for `Y=1` goes to rank 1, and so on, until the entire input buffer has been distributed.
Concrete Example:
Imagine Rank 2 holds an array of data [10, 20, 30, 40], and it wants to give each rank one element to work on.
- Initial State:
  - Rank 0: has no data
  - Rank 1: has no data
  - Rank 2: has data [10, 20, 30, 40]
  - Rank 3: has no data
- Step-by-Step Process:
  - Initiation: Rank 2, as the root, prepares its array for distribution. It conceptually splits the array into chunks, one for each rank (including itself).
  - Distribution: Rank 2 sends a unique chunk to each corresponding rank. It sends 10 to Rank 0, 20 to Rank 1, and 40 to Rank 3. It keeps the third chunk, 30, for itself. Using Rank 0 as an example: `Y = 0` because the destination is Rank 0, `count = 1` because each rank gets 1 integer, and `i = 0` because we want the first (and only) integer. So the value that Rank 0 receives is `in[0*1+0] = in[0] = 10`.
  - Reception: Each rank receives its assigned piece of the data.
- Final State:
  - Rank 0: has data 10
  - Rank 1: has data 20
  - Rank 2: has data 30
  - Rank 3: has data 40
Analogy: A teacher (Rank 2) with a stack of different exam questions [10, 20, 30, 40] giving one unique question to each student (the other ranks).
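The scatter formula translates directly into a list comprehension. This Python sketch is a single-process simulation (the function name is hypothetical, not an MPI or NCCL call) that returns the `outY` buffer of every rank `Y`, assuming the root's buffer length is a multiple of the number of ranks.

```python
from typing import List


def scatter(root_in: List[int], num_ranks: int) -> List[List[int]]:
    """Simulate outY[i] = in[Y*count+i] for every destination rank Y."""
    if len(root_in) % num_ranks != 0:
        raise ValueError("input length must be a multiple of the number of ranks")
    count = len(root_in) // num_ranks
    return [[root_in[y * count + i] for i in range(count)] for y in range(num_ranks)]


outs = scatter([10, 20, 30, 40], num_ranks=4)
for rank, out in enumerate(outs):
    print(f"Rank {rank}: {out}")  # Rank 0: [10], Rank 1: [20], Rank 2: [30], Rank 3: [40]
```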
Gather
A gather operation is the inverse of a scatter. It collects individual data values from all ranks in the group and assembles them into an array on a single designated root rank.
The data placement for a gather operation can be described with the following formula, which specifies how the root rank constructs its output buffer:
out[Y*count+i] = inY[i]
(This is valid on the root rank only)
This formula describes a direct data copy from each rank's input buffer to a specific location in the root rank's output buffer.
- `inY[i]`: This is the source of the data. It represents the i-th element from the input buffer of a specific source rank `Y`.
- `out[...]`: This is the destination for the data. It refers to the large output buffer that is being constructed only on the designated root rank. This is the key difference between Gather and All-Gather.
- `[Y*count+i]`: This is the destination index within the root's `out` buffer. It acts as the "placement logic" to ensure data is assembled correctly.
  - `Y`: The rank of the source process whose data is being placed.
  - `count`: The number of elements each rank is contributing (the size of each `inY` buffer).
  - `i`: The local index within that source chunk.
In essence, the formula states that the root rank iterates through each source rank `Y`, takes its input data `inY`, and copies it into the correct block of the final `out` buffer.
Concrete Example:
Following the scatter example, let's say each rank has performed a calculation on its data and now has a result. Rank 2 needs to collect all these results.
- Initial State:
  - Rank 0: has result A
  - Rank 1: has result B
  - Rank 2: has result C
  - Rank 3: has result D
- Step-by-Step Process:
  - Initiation: Rank 2 is designated as the root rank that will receive all the data.
  - Transmission: Each rank sends its individual data value to Rank 2. Rank 0 sends A, Rank 1 sends B, and Rank 3 sends D.
  - Collection: Rank 2 receives the data from all ranks and assembles it, together with its own value C, into an array ordered by rank ID, as the formula dictates.
- Final State:
  - Rank 0: still has its data A
  - Rank 1: still has its data B
  - Rank 2: has data [A, B, C, D]
  - Rank 3: still has its data D
Analogy: A teacher (Rank 2) collecting the completed, unique exam answers (A, B, C, D) from every student (the other ranks) to grade them.
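Gather is the same index arithmetic run in reverse. In this single-process Python sketch (hypothetical function name, not a library API), non-root ranks are given `None` to show that they receive no output buffer.

```python
from typing import List, Optional


def gather(inputs: List[list], root: int) -> List[Optional[list]]:
    """Simulate out[Y*count+i] = inY[i]; only the root gets an output buffer."""
    count = len(inputs[0])
    out = [None] * (len(inputs) * count)
    for y, in_y in enumerate(inputs):
        for i in range(count):
            out[y * count + i] = in_y[i]
    # Non-root ranks have no output buffer (represented here as None).
    return [out if rank == root else None for rank in range(len(inputs))]


outs = gather([["A"], ["B"], ["C"], ["D"]], root=2)
print(outs)  # [None, None, ['A', 'B', 'C', 'D'], None]
```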
Reduce
Source: Nvidia NCCL documentation
A reduce operation is similar to a gather, but as it collects data from all ranks, it combines them into a single final value using a specified operation (like sum, max, min, or logical AND). This final result is stored on only one root rank.
The formula `out[i] = sum(inX[i])` describes an element-wise reduction across all participating ranks.
- `inX[i]`: This refers to the i-th element of the input buffer from a specific rank `X`.
  - `in0[i]` is the i-th element from rank 0's input buffer.
  - `in1[i]` is the i-th element from rank 1's input buffer.
  - `in2[i]` is the i-th element from rank 2's input buffer.
  - `in3[i]` is the i-th element from rank 3's input buffer.
  Imagine each `in` buffer is an array of numbers; `i` is the index into that array.
- `sum(...)`: This specifies the reduction operation to be performed. In this case, it's a summation. For each index `i`, the operation takes the corresponding elements from all the input buffers and adds them together. So, the calculation is: `in0[i] + in1[i] + in2[i] + in3[i]`.
- `out[i]`: This represents the i-th element of the final output buffer. As the diagram shows, this output buffer exists only on the designated root rank (in this case, rank 2).
Putting It All Together
The expression `out[i] = sum(inX[i])` means:
For every position `i` in the data buffers, the value at `out[i]` on the root rank is calculated by summing the values at that same position `i` from the input buffers of all ranks (rank 0, rank 1, rank 2, and rank 3).
A Concrete Example
Let's assume each rank has an input buffer with three elements:
- rank 0 `in0`: [10, 20, 5]
- rank 1 `in1`: [2, 4, 6]
- rank 2 `in2`: [1, 1, 1]
- rank 3 `in3`: [5, 10, 15]
The root is rank 2.
After the Reduce operation, the `out` buffer on rank 2 will be calculated as follows:
- `out[0] = sum(inX[0]) = in0[0] + in1[0] + in2[0] + in3[0] = 10 + 2 + 1 + 5 = 18`
- `out[1] = sum(inX[1]) = in0[1] + in1[1] + in2[1] + in3[1] = 20 + 4 + 1 + 10 = 35`
- `out[2] = sum(inX[2]) = in0[2] + in1[2] + in2[2] + in3[2] = 5 + 6 + 1 + 15 = 27`
Final State:
- rank 0, 1, and 3: Their input buffers remain unchanged. They have no `out` buffer.
- rank 2 (root): Now has the `out` buffer: [18, 35, 27].
Analogy: Team members (ranks) each count the number of widgets they made. They report their counts up a chain of command, with each manager summing the counts from their reports, until the final boss (Rank 2, the root) has the grand total.
It is worth noting that many libraries provide an 'in-place' option for these operations. When operating in-place, the result is stored directly back into the input buffer, which can be more memory-efficient as it avoids the need to allocate a separate output buffer.
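The reduction can be simulated by combining the inputs column by column. This Python sketch (hypothetical function name; a single-process simulation, not MPI or NCCL) takes the reduction operator as a parameter, so the same code covers sum as well as the max and min operations mentioned above.

```python
import operator
from functools import reduce as fold
from typing import Callable, List, Optional


def reduce_to_root(
    inputs: List[List[int]], root: int, op: Callable[[int, int], int] = operator.add
) -> List[Optional[List[int]]]:
    """Simulate out[i] = op(in0[i], in1[i], ...); only the root gets `out`."""
    # zip(*inputs) yields one tuple per index i: (in0[i], in1[i], in2[i], in3[i]).
    out = [fold(op, column) for column in zip(*inputs)]
    return [out if rank == root else None for rank in range(len(inputs))]


inputs = [[10, 20, 5], [2, 4, 6], [1, 1, 1], [5, 10, 15]]
print(reduce_to_root(inputs, root=2))             # [None, None, [18, 35, 27], None]
print(reduce_to_root(inputs, root=2, op=max)[2])  # [10, 20, 15]
```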
All-reduce
Source: Nvidia NCCL documentation
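Conceptually, an all-reduce operation is a reduce followed by a broadcast: it performs a reduction (combining values from all ranks into a single result) and then makes that final result available on every rank. Efficient implementations do not necessarily run these two steps literally; the section "Why Reduce-Scatter + All-gather = All-reduce" describes a common alternative.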
The formula `out[i] = sum(inX[i])` is identical to the one for Reduce, but its implication is different because of where the output buffer (`out`) is stored.
- `inX[i]`: This refers to the i-th element of the input buffer from a specific rank `X`.
  - `in0[i]` is the i-th element from rank 0.
  - `in1[i]` is the i-th element from rank 1, and so on.
- `sum(...)`: This is the reduction operation. For each index `i`, it calculates the sum of the elements at that position from all ranks: `in0[i] + in1[i] + in2[i] + in3[i]`.
- `out[i]`: This is the crucial difference from a standard Reduce. `out[i]` represents the i-th element of the output buffer, and as the diagram clearly shows, this `out` buffer is present on every single rank (rank 0, rank 1, rank 2, and rank 3).
So, the expression `out[i] = sum(inX[i])` for an All-Reduce means:
For every position `i`, calculate the sum of the values at that position from the input buffers of all ranks, and place this final sum into the `out[i]` position on every participating rank.
A Concrete Example
Let's use the same initial data as the Reduce example:
- rank 0 `in0`: [10, 20, 5]
- rank 1 `in1`: [2, 4, 6]
- rank 2 `in2`: [1, 1, 1]
- rank 3 `in3`: [5, 10, 15]
The intermediate Reduce calculation is the same:
- `sum(inX[0]) = 10 + 2 + 1 + 5 = 18`
- `sum(inX[1]) = 20 + 4 + 1 + 10 = 35`
- `sum(inX[2]) = 5 + 6 + 1 + 15 = 27`
The final, reduced buffer is [18, 35, 27].
Now, the Broadcast phase ensures this result is distributed.
Final State:
After the All-Reduce operation, every single rank will have an identical `out` buffer:
- rank 0 `out`: [18, 35, 27]
- rank 1 `out`: [18, 35, 27]
- rank 2 `out`: [18, 35, 27]
- rank 3 `out`: [18, 35, 27]
Analogy: After all team members (ranks) report their individual widget counts up the chain to get a grand total, the final boss (who now knows the total) sends out a company-wide email announcing the total number of widgets produced so that every employee is aware of the final number.
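In a single-process simulation, the only change from the reduce sketch is where the result ends up: every rank receives its own copy. This Python sketch (hypothetical function name, not a library API) mirrors the conceptual "reduce, then broadcast" description rather than any optimized algorithm.

```python
import operator
from functools import reduce as fold
from typing import Callable, List


def all_reduce(
    inputs: List[List[int]], op: Callable[[int, int], int] = operator.add
) -> List[List[int]]:
    """Simulate out[i] = op(in0[i], in1[i], ...) with `out` present on every rank."""
    reduced = [fold(op, column) for column in zip(*inputs)]  # the "reduce" part
    return [list(reduced) for _ in inputs]                   # the "broadcast" part


outs = all_reduce([[10, 20, 5], [2, 4, 6], [1, 1, 1], [5, 10, 15]])
for rank, out in enumerate(outs):
    print(f"rank {rank} out: {out}")  # [18, 35, 27] on every rank
```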
Reduce-Scatter
Source: Nvidia NCCL documentation
A reduce-scatter operation first combines data from all ranks using a specified operation (like a reduce) and then scatters the resulting combined data chunks back to the ranks. So, each rank ends up with a piece of the final result.
The formula `outY[i] = sum(inX[Y*count+i])` combines the reduce and scatter steps into a single mathematical expression.
- `outY[i]`: This represents the i-th element of the output buffer on a specific destination rank `Y`.
  - `out0[i]` is the i-th element of the output buffer on rank 0.
  - `out1[i]` is the i-th element of the output buffer on rank 1.
  - ...and so on. This matches the diagram, where each rank `Y` gets its own `outY` buffer.
- `sum(...)`: This is the reduction operation (e.g., sum, max, min). It combines data from all the input ranks.
- `inX[...]`: This refers to the input buffer on rank `X`, where the sum is performed over all possible values of `X`.
- `[Y*count+i]`: The index into the large input buffers. It maps the local index `i` on the output rank `Y` to a global index in the input data.
  - `Y`: The rank of the process that will receive this specific piece of data.
  - `count`: The number of elements in each final output chunk (i.e., the size of `out0`, `out1`, etc.). In the diagram, if each `out` buffer has, for example, 10 elements, then `count` would be 10.
  - `i`: The local index within a final output chunk. It ranges from 0 to `count-1`.
The expression `Y*count+i` is an offset calculation. For rank 0 (`Y=0`), it calculates the sum for the first `count` elements. For rank 1 (`Y=1`), it calculates the sum for the next `count` elements, and so on.
Putting It All Together
The formula `outY[i] = sum(inX[Y*count+i])` means:
The value of the i-th element in the output buffer of rank `Y` is the sum, over all ranks `X`, of the input elements at the global index `(Y * count + i)`.
A Concrete Example
Let's assume there are 4 ranks and the goal is for each rank to receive 1 final value (so `count = 1`). This means the total result has 4 elements, so each input buffer must also have 4 elements.
Initial State:
- rank 0 `in0`: [10, 20, 5, 1]
- rank 1 `in1`: [2, 4, 6, 2]
- rank 2 `in2`: [1, 1, 1, 3]
- rank 3 `in3`: [5, 10, 15, 4]
The Calculation:
- For rank 0 (`Y=0`): It needs to calculate its output buffer, `out0`. Since `count=1`, `i` can only be 0.
  - `out0[0] = sum(inX[0*1+0]) = sum(inX[0])`
  - `out0[0] = in0[0] + in1[0] + in2[0] + in3[0] = 10 + 2 + 1 + 5 = 18`
- For rank 1 (`Y=1`): It needs to calculate `out1`. Again, `i=0`.
  - `out1[0] = sum(inX[1*1+0]) = sum(inX[1])`
  - `out1[0] = in0[1] + in1[1] + in2[1] + in3[1] = 20 + 4 + 1 + 10 = 35`
- For rank 2 (`Y=2`): It needs to calculate `out2`.
  - `out2[0] = sum(inX[2*1+0]) = sum(inX[2])`
  - `out2[0] = in0[2] + in1[2] + in2[2] + in3[2] = 5 + 6 + 1 + 15 = 27`
- For rank 3 (`Y=3`): It needs to calculate `out3`.
  - `out3[0] = sum(inX[3*1+0]) = sum(inX[3])`
  - `out3[0] = in0[3] + in1[3] + in2[3] + in3[3] = 1 + 2 + 3 + 4 = 10`
Final State:
- rank 0: `out0 = [18]`
- rank 1: `out1 = [35]`
- rank 2: `out2 = [27]`
- rank 3: `out3 = [10]`
Analogy: A group of authors (ranks) are collaboratively writing a book. They first combine their individual draft sections for each chapter into a single, final version (the reduce step). Then, each author takes one of the final chapters to do a final proofread (the scatter step).
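The reduce-scatter formula can be evaluated directly for every rank. This Python sketch is a single-process simulation (hypothetical function name, not an MPI or NCCL call) and assumes each input buffer's length is a multiple of the number of ranks.

```python
from typing import List


def reduce_scatter(inputs: List[List[int]]) -> List[List[int]]:
    """Simulate outY[i] = sum over X of inX[Y*count+i] for every rank Y."""
    num_ranks = len(inputs)
    total = len(inputs[0])
    if total % num_ranks != 0:
        raise ValueError("input length must be a multiple of the number of ranks")
    count = total // num_ranks
    return [
        [sum(in_x[y * count + i] for in_x in inputs) for i in range(count)]
        for y in range(num_ranks)
    ]


inputs = [[10, 20, 5, 1], [2, 4, 6, 2], [1, 1, 1, 3], [5, 10, 15, 4]]
print(reduce_scatter(inputs))  # [[18], [35], [27], [10]]
```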
All-gather
Source: Nvidia NCCL documentation
An all-gather operation is where every rank collects all the individual data values from all other ranks. Unlike a standard gather where only one root rank receives the data, in an all-gather, every rank ends up with a complete set of the data from all ranks.
The formula `out[Y*count+i] = inY[i]` describes a direct data copy and placement, not a mathematical reduction like sum. It dictates how the smaller input buffers are assembled into the larger output buffer.
- `inY[i]`: This is the source of the data. It refers to the i-th element of the input buffer from a specific source rank `Y`.
  - `in0[i]` is the i-th element from rank 0's input.
  - `in1[i]` is the i-th element from rank 1's input, and so on.
- `out[...]`: This is the destination for the data. It refers to the large output buffer that is being constructed on every rank.
- `[Y*count+i]`: This is the destination index. It's the "placement logic" that determines where the source data `inY[i]` should be copied to in the final `out` buffer.
  - `Y`: The rank of the source process whose data is being placed.
  - `count`: The number of elements in each individual input chunk (i.e., the size of `in0`, `in1`, etc.).
  - `i`: The local index within that source chunk, ranging from 0 to `count-1`.
This indexing calculation ensures that the input buffers are concatenated in the correct order (`in0` first, then `in1`, then `in2`, etc.).
Putting It All Together
The formula `out[Y*count+i] = inY[i]` means:
The data from the i-th position of the input buffer of rank `Y` is copied directly into the final output buffer at the global index `(Y * count + i)`. This process is repeated for all source ranks `Y`, and the resulting `out` buffer is identical on every rank.
A Concrete Example
Let's assume each rank has an input buffer with two elements (so `count = 2`).
Initial State:
- rank 0 `in0`: [10, 11]
- rank 1 `in1`: [20, 21]
- rank 2 `in2`: [30, 31]
- rank 3 `in3`: [40, 41]
The All-Gather operation will construct the `out` buffer on all ranks as follows:
- Placing data from rank 0 (`Y=0`):
  - `out[0*2+0] = in0[0]` => `out[0] = 10`
  - `out[0*2+1] = in0[1]` => `out[1] = 11`
- Placing data from rank 1 (`Y=1`):
  - `out[1*2+0] = in1[0]` => `out[2] = 20`
  - `out[1*2+1] = in1[1]` => `out[3] = 21`
- Placing data from rank 2 (`Y=2`):
  - `out[2*2+0] = in2[0]` => `out[4] = 30`
  - `out[2*2+1] = in2[1]` => `out[5] = 31`
- Placing data from rank 3 (`Y=3`):
  - `out[3*2+0] = in3[0]` => `out[6] = 40`
  - `out[3*2+1] = in3[1]` => `out[7] = 41`
Final State:
After the operation, every single rank will have the identical, fully assembled `out` buffer:
- rank 0 `out`: [10, 11, 20, 21, 30, 31, 40, 41]
- rank 1 `out`: [10, 11, 20, 21, 30, 31, 40, 41]
- rank 2 `out`: [10, 11, 20, 21, 30, 31, 40, 41]
- rank 3 `out`: [10, 11, 20, 21, 30, 31, 40, 41]
Analogy: At a business meeting, each department head (the ranks) has their department's quarterly report (A, B, C, D). They don't just send their reports to the CEO. Instead, they each make copies of their report and pass them around the table until every department head has a complete binder containing the reports from all other departments.
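The all-gather placement rule is the gather rule with a copy of the result handed to every rank. This Python sketch is a single-process simulation (hypothetical function name, not a library API) that reproduces the example above.

```python
from typing import List


def all_gather(inputs: List[list]) -> List[list]:
    """Simulate out[Y*count+i] = inY[i], with the same `out` on every rank."""
    count = len(inputs[0])
    out = [None] * (len(inputs) * count)
    for y, in_y in enumerate(inputs):
        for i in range(count):
            out[y * count + i] = in_y[i]
    # Unlike gather, every rank (not just a root) receives its own copy.
    return [list(out) for _ in inputs]


outs = all_gather([[10, 11], [20, 21], [30, 31], [40, 41]])
print(outs[0])  # [10, 11, 20, 21, 30, 31, 40, 41]
```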
How is All-gather Different from Gather?
The key difference between gather and all-gather lies in who receives the final, complete collection of data.
- In a gather operation, only one designated "root" rank collects the data from all the other ranks. The other ranks only send their data; they do not receive the final assembled array.
- In an all-gather operation, every rank involved in the operation receives the final, complete set of data from all other ranks. It's as if a gather operation happens for every single rank.
Why Reduce-Scatter + All-gather = All-reduce
The equation `reduce-scatter + all-gather = all-reduce` illustrates how two more fundamental operations can be combined to create a more complex one. Let's trace this with a concrete example. The goal is an All-reduce with the sum operation.
1. Initial State: Each rank has an array of data. The size of the array equals the number of ranks.
   - Rank 0 `in`: [10, 20, 5, 1]
   - Rank 1 `in`: [2, 4, 6, 2]
   - Rank 2 `in`: [1, 1, 1, 3]
   - Rank 3 `in`: [5, 10, 15, 4]
2. Perform a Reduce-Scatter: This operation performs an element-wise sum and then scatters the resulting array, so each rank gets one piece of the final sum.
   - The first element of the result is 10+2+1+5 = 18. This goes to Rank 0.
   - The second element of the result is 20+4+1+10 = 35. This goes to Rank 1.
   - The third element of the result is 5+6+1+15 = 27. This goes to Rank 2.
   - The fourth element of the result is 1+2+3+4 = 10. This goes to Rank 3.
3. Intermediate State (after Reduce-Scatter): Each rank now holds its portion of the reduced result.
   - Rank 0 has: [18]
   - Rank 1 has: [35]
   - Rank 2 has: [27]
   - Rank 3 has: [10]
4. Perform an All-gather: Now, each rank takes its piece and shares it with all other ranks, assembling the full array everywhere.
   - Final State (after All-gather):
     - Rank 0 has [18, 35, 27, 10]
     - Rank 1 has [18, 35, 27, 10]
     - Rank 2 has [18, 35, 27, 10]
     - Rank 3 has [18, 35, 27, 10]
This final state, where every rank holds the complete array of summed values [18, 35, 27, 10], is exactly the outcome of a single All-reduce operation.
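The equivalence can also be checked mechanically. This self-contained Python sketch re-implements the three operations as single-process simulations (hypothetical helper names, assuming the buffer length is a multiple of the number of ranks) and compares the composed result with a direct all-reduce.

```python
from typing import List


def reduce_scatter(inputs: List[List[int]]) -> List[List[int]]:
    """outY[i] = sum over X of inX[Y*count+i]."""
    count = len(inputs[0]) // len(inputs)
    return [
        [sum(x[y * count + i] for x in inputs) for i in range(count)]
        for y in range(len(inputs))
    ]


def all_gather(pieces: List[List[int]]) -> List[List[int]]:
    """out[Y*count+i] = inY[i] on every rank (concatenation in rank order)."""
    out = [v for piece in pieces for v in piece]
    return [list(out) for _ in pieces]


def all_reduce(inputs: List[List[int]]) -> List[List[int]]:
    """out[i] = sum over X of inX[i] on every rank."""
    out = [sum(column) for column in zip(*inputs)]
    return [list(out) for _ in inputs]


inputs = [[10, 20, 5, 1], [2, 4, 6, 2], [1, 1, 1, 3], [5, 10, 15, 4]]
intermediate = reduce_scatter(inputs)  # [[18], [35], [27], [10]]
composed = all_gather(intermediate)    # [18, 35, 27, 10] on every rank
print(composed == all_reduce(inputs))  # True
```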
Is there any communication cost difference between reduce-scatter + all-gather and all-reduce?
At a high level, the theoretical communication cost of an efficient ring-based all-reduce is the same as that of a ring-based reduce-scatter followed by a ring-based all-gather. A ring-based all-reduce is, for all practical purposes, implemented as a fused reduce-scatter and all-gather. They are not two competing algorithms in this context; rather, one is built directly from the other two. With P ranks and a buffer of n elements, each phase takes P-1 steps in which every rank sends one chunk of n/P elements to its neighbour, so each rank sends 2(P-1)/P × n elements in total over 2(P-1) steps. By breaking the problem down this way, data can be continuously streamed around the ring of nodes, ensuring that network links are always busy and that each rank is always either sending or receiving data. This pipelining approach is highly efficient because it avoids the distinct synchronization points that would exist if you performed a full Reduce to one node, which then had to perform a full Broadcast.
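The ring structure can be simulated in a few dozen lines of Python. This is a sketch of the standard ring algorithm under stated assumptions, not NCCL's or any MPI library's actual code: rank r always sends to rank (r + 1) mod P, the buffer is split into P equal chunks (so its length is assumed to be a multiple of P), and a per-rank counter records how many elements each rank sends.

```python
from typing import Callable, List, Tuple


def ring_all_reduce(inputs: List[List[int]]) -> Tuple[List[List[int]], List[int]]:
    """Ring all-reduce as a reduce-scatter phase followed by an all-gather phase.

    Returns every rank's final buffer and the number of elements each rank sent.
    """
    p, n = len(inputs), len(inputs[0])
    if n % p != 0:
        raise ValueError("buffer length must be a multiple of the number of ranks")
    count = n // p
    bufs = [list(b) for b in inputs]
    sent = [0] * p

    def chunk(c: int) -> range:
        return range(c * count, (c + 1) * count)

    def ring_step(pick: Callable[[int], int], combine: Callable[[int, int], int]) -> None:
        # Every rank r sends one chunk to its right neighbour (r + 1) % p. All
        # messages are built before any is applied, so the sends are simultaneous.
        msgs = []
        for r in range(p):
            c = pick(r)
            msgs.append(((r + 1) % p, c, [bufs[r][j] for j in chunk(c)]))
            sent[r] += count
        for dst, c, vals in msgs:
            for j, v in zip(chunk(c), vals):
                bufs[dst][j] = combine(bufs[dst][j], v)

    # Phase 1 (reduce-scatter): after p-1 steps, rank r holds the full sum of chunk (r+1) % p.
    for step in range(p - 1):
        ring_step(lambda r, s=step: (r - s) % p, lambda old, new: old + new)
    # Phase 2 (all-gather): pass the finished chunks around, overwriting partial sums.
    for step in range(p - 1):
        ring_step(lambda r, s=step: (r + 1 - s) % p, lambda old, new: new)
    return bufs, sent


inputs = [[10, 20, 5, 1], [2, 4, 6, 2], [1, 1, 1, 3], [5, 10, 15, 4]]
result, sent = ring_all_reduce(inputs)
print(result[0])  # [18, 35, 27, 10], identical on every rank
print(sent)       # [6, 6, 6, 6]: 2 * (P - 1) * n / P with P = 4, n = 4
```

Because every rank sends exactly one chunk per step, no single link or rank becomes a bottleneck, which is the property the paragraph above attributes to the ring.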
Appendix A: Non-blocking all-gather
Let us imagine a scenario where each rank needs to gather model weights from all other ranks and then use those weights to calculate a local gradient.
Scenario 1: Blocking All-gather
In this scenario, computation and communication happen in a strict sequence. The time spent waiting for the network is "wasted" time for the CPU.
// Each rank has its local weights in `local_weights_buffer`
// `all_weights_buffer` is the output buffer
1. print("Starting blocking All-gather...");
2. MPI_Allgather(local_weights_buffer, ..., all_weights_buffer, ...); // <-- PROGRAM PAUSES HERE
// Each rank waits here until it has received the data from every rank.
// The CPU is idle during this time.
3. print("All-gather complete. All weights are now in all_weights_buffer.");
4. print("Now, starting gradient calculation...");
5. calculate_gradients(all_weights_buffer); // <-- Can only start after step 2 is fully done.
Scenario 2: Non-blocking All-gather
By overlapping the independent computation (Step 4) with the communication (started in Step 2), you effectively "hide" the network latency. When there is enough independent work to cover the communication time, the total execution time can be significantly reduced.
// Buffers are the same, plus a request handle
MPI_Request request;
1. print("Initiating NON-BLOCKING All-gather...");
2. MPI_Iallgather(local_weights_buffer, ..., all_weights_buffer, ..., &request); // <-- RETURNS IMMEDIATELY
3. print("All-gather is running in the background. Now doing other work...");
4. perform_independent_computation(); // <-- THIS IS THE OVERLAP!
// This could be anything that doesn't need `all_weights_buffer`.
// e.g., preparing data for the *next* iteration.
5. print("Finished other work. Now I need the weights. Waiting for All-gather to complete...");
6. MPI_Wait(&request, ...); // <-- PROGRAM PAUSES HERE, ONLY IF IT'S NOT ALREADY DONE.
// Guarantees that the operation is complete and the buffer is safe to read.
7. print("All-gather is confirmed complete.");
8. calculate_gradients(all_weights_buffer); // <-- Safe to run now.
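The overlap in Scenario 2 can be imitated in plain Python without MPI. In this sketch (a simulation, not MPI code) a background thread from concurrent.futures stands in for the network, time.sleep stands in for transfer and compute time, submit() plays the role of MPI_Iallgather, and Future.result() plays the role of MPI_Wait. The helper names and the 0.2 s durations are hypothetical, and both versions do the same independent work so their timings are comparable.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def simulated_all_gather(local_chunks, delay_s):
    """Stand-in for network communication: waits, then concatenates the chunks."""
    time.sleep(delay_s)
    return [x for chunk in local_chunks for x in chunk]


def independent_work(delay_s):
    """Stand-in for computation that does not need the gathered buffer."""
    time.sleep(delay_s)
    return "next batch prepared"


chunks = [[10, 11], [20, 21], [30, 31], [40, 41]]
COMM_S, WORK_S = 0.2, 0.2

# Blocking: communication, then the independent work, strictly in sequence.
start = time.perf_counter()
all_weights = simulated_all_gather(chunks, COMM_S)
independent_work(WORK_S)
blocking_s = time.perf_counter() - start

# Non-blocking style: start communication, overlap the work, then wait.
start = time.perf_counter()
with ThreadPoolExecutor(max_workers=1) as pool:
    request = pool.submit(simulated_all_gather, chunks, COMM_S)  # like MPI_Iallgather
    independent_work(WORK_S)                                     # the overlap
    all_weights_nb = request.result()                            # like MPI_Wait
overlapped_s = time.perf_counter() - start

print(f"blocking: {blocking_s:.2f}s, overlapped: {overlapped_s:.2f}s")
```

On a typical machine the overlapped version should take roughly half as long, because the two 0.2 s waits run concurrently instead of back to back; the gathered buffer is the same in both cases.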