Ring All-Reduce
Simulate ring all-reduce, the distributed primitive that lets N workers collectively sum their tensors using only ring-local communication.
Why ring topology?
Naive all-reduce (one worker collects everything, then broadcasts) is
bandwidth-bottlenecked at the coordinator. A tree reduce cuts the depth to
O(log N) but still concentrates traffic. The ring topology is
bandwidth-optimal: the total data sent per worker is exactly
2 * (N-1)/N * tensor_size, which stays below 2 * tensor_size no matter how
large N grows. Ring all-reduce is among the algorithms used by NCCL, Horovod,
and (through its communication backends) PyTorch DDP for large-scale training.
Reference: Patarasuk & Yuan (2009), Bandwidth Optimal All-reduce Algorithms for Clusters of Workstations.
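To make the bandwidth claim concrete, the short sketch below evaluates the per-worker formula for a few ring sizes and compares it with the traffic arriving at the coordinator in the naive scheme. The function names are hypothetical, and the naive figure assumes the coordinator receives one full tensor from each of the other N-1 workers.

```python
def ring_data_sent_per_worker(num_workers, tensor_size):
    """Data each worker sends in a ring all-reduce: 2 * (N-1)/N * tensor_size."""
    if num_workers < 1:
        raise ValueError("num_workers must be >= 1")
    return 2 * (num_workers - 1) / num_workers * tensor_size


def naive_data_received_by_coordinator(num_workers, tensor_size):
    """Assumed naive scheme: the coordinator receives one full tensor from
    each of the other N-1 workers before broadcasting the result."""
    if num_workers < 1:
        raise ValueError("num_workers must be >= 1")
    return (num_workers - 1) * tensor_size


if __name__ == "__main__":
    # tensor_size = 1.0, so the printed values are multiples of the tensor size.
    for n in (2, 4, 8, 64):
        ring = ring_data_sent_per_worker(n, 1.0)
        naive = naive_data_received_by_coordinator(n, 1.0)
        print(f"N={n:3d}  ring per worker={ring:.3f}  naive coordinator={naive:.1f}")
```

The ring figure approaches but never reaches 2 * tensor_size, while the coordinator's load in the naive scheme grows linearly with N.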
Two-phase algorithm:
Each phase runs for exactly N - 1 rounds and communicates one chunk per worker per round.
Phase 1 - Scatter-Reduce (N-1 rounds):
Chunk each worker's tensor along axis 0 into N equal pieces.
In each round, worker i sends one chunk to its right neighbor (i+1 mod N),
who accumulates it (adds it to the same chunk slot in its own buffer).
After N-1 rounds, worker i holds the full sum for exactly one chunk.
Phase 2 - All-Gather (N-1 rounds):
Each worker now broadcasts its completed (fully-summed) chunk around
the ring. After N-1 rounds, every worker holds every chunk, so the full
sum is reconstructed on all workers.
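The two phases can be sketched as a single-process simulation. This is one possible illustration, not the reference solution: it assumes NumPy arrays, uses the hypothetical name simulate_ring_all_reduce, and picks one common chunk schedule (in round r of phase 1, worker i sends chunk (i - r) mod N), which the text above does not prescribe. It returns every worker's final buffer so the result can be checked on all workers.

```python
import numpy as np


def simulate_ring_all_reduce(worker_tensors):
    """Simulate both phases; returns all workers' final buffers, shape (N, ...)."""
    x = np.asarray(worker_tensors)
    n = x.shape[0]
    if x.ndim < 2 or x.shape[1] % n != 0:
        raise ValueError("each worker's leading dimension must be divisible by N")

    # chunks[i][c] is worker i's local copy of chunk c (split along the
    # per-worker tensor's axis 0).
    chunks = [list(np.split(x[i].copy(), n, axis=0)) for i in range(n)]

    # Phase 1: scatter-reduce. In round r, worker i sends chunk (i - r) mod N
    # to its right neighbor, which adds it into the same chunk slot.
    for r in range(n - 1):
        # Snapshot every outgoing message first so all sends in a round
        # behave as if they happened simultaneously.
        sent = [((i - r) % n, chunks[i][(i - r) % n].copy()) for i in range(n)]
        for i, (c, payload) in enumerate(sent):
            right = (i + 1) % n
            chunks[right][c] = chunks[right][c] + payload
    # Under this schedule, worker i now holds the full sum for chunk (i + 1) mod N.

    # Phase 2: all-gather. In round r, worker i forwards chunk (i + 1 - r) mod N,
    # and the right neighbor overwrites its stale copy with it.
    for r in range(n - 1):
        sent = [((i + 1 - r) % n, chunks[i][(i + 1 - r) % n].copy()) for i in range(n)]
        for i, (c, payload) in enumerate(sent):
            chunks[(i + 1) % n][c] = payload

    # Reassemble each worker's chunks into a full tensor.
    return np.stack([np.concatenate(chunks[i], axis=0) for i in range(n)])


if __name__ == "__main__":
    demo = np.arange(4 * 8 * 3, dtype=np.float64).reshape(4, 8, 3)
    buffers = simulate_ring_all_reduce(demo)
    print(np.allclose(buffers, demo.sum(axis=0)))  # compares every worker's buffer
```

Taking a snapshot of the outgoing chunks before applying any of them keeps a worker from forwarding data it received in the same round, which is what makes the sequential loop match the parallel algorithm.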
Simplification for this problem:
worker_tensors.shape[1] (the leading dimension of each worker's tensor)
must be divisible by N so that chunks are equal-sized.
Input convention:
worker_tensors is a stacked tensor of shape (N, ...) where
worker_tensors[i] is the tensor held by worker i.
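As a small illustration of this convention, the snippet below stacks per-worker tensors and splits one worker's tensor into N equal chunks. It assumes NumPy arrays; the example shapes and values are made up for demonstration.

```python
import numpy as np

# Hypothetical example data: 3 workers, each holding a tensor of shape (6, 2)
# filled with that worker's index.
per_worker = [np.full((6, 2), fill_value=float(i)) for i in range(3)]

# Stack along a new leading axis so that worker_tensors[i] is worker i's tensor.
worker_tensors = np.stack(per_worker, axis=0)  # shape (3, 6, 2)

num_workers = worker_tensors.shape[0]
# Each worker's leading dimension (shape[1] of the stack) must divide evenly by N.
assert worker_tensors.shape[1] % num_workers == 0

# Split worker 0's tensor along its own axis 0 into N equal chunks.
worker0_chunks = np.split(worker_tensors[0], num_workers, axis=0)
print(worker_tensors.shape, len(worker0_chunks), worker0_chunks[0].shape)
```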
Function signature:
def ring_all_reduce(worker_tensors):
    """Simulate ring-all-reduce sum across N workers.

    worker_tensors: shape (N, ...), i.e. N workers, each holding a tensor of
        identical shape. worker_tensors[i] is worker i's tensor.
    Returns: tensor of shape (...), the SUM of all worker_tensors along
        the worker (first) dimension.
    Requirement: worker_tensors.shape[1] must be divisible by N.
    """
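Since the docstring defines the result as the sum over the worker dimension, any implementation can be compared against a plain sum along axis 0. The helper below is a hypothetical test harness, assuming NumPy inputs; check_all_reduce and its parameters are not part of the problem statement.

```python
import numpy as np


def check_all_reduce(candidate, num_workers=4, rows_per_chunk=2, trailing=(3,), seed=0):
    """Return True if candidate(worker_tensors) matches the sum over axis 0."""
    rng = np.random.default_rng(seed)
    # The per-worker leading dimension is a multiple of N by construction.
    shape = (num_workers, num_workers * rows_per_chunk, *trailing)
    worker_tensors = rng.standard_normal(shape)
    expected = worker_tensors.sum(axis=0)
    result = np.asarray(candidate(worker_tensors))
    return result.shape == expected.shape and bool(np.allclose(result, expected))


if __name__ == "__main__":
    # A trivial stand-in that skips the ring entirely; a real ring_all_reduce
    # implementation would be passed in its place.
    print(check_all_reduce(lambda t: t.sum(axis=0)))
```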