Ring All-Reduce

Simulate ring all-reduce, the distributed primitive that lets N workers collectively sum their tensors using only ring-local communication.

Why ring topology?

Naive all-reduce (one worker collects everything, then broadcasts) is bandwidth-bottlenecked at the coordinator. A tree reduce cuts the depth to O(log N) but still concentrates traffic. The ring topology is bandwidth-optimal: each worker sends exactly 2 * (N-1)/N * tensor_size in total, which stays below 2 * tensor_size no matter how large N grows. It is among the all-reduce algorithms used by NCCL, Horovod, and PyTorch DDP (through its communication backends) for large-scale training.
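The bandwidth claim is easy to check numerically. The sketch below is illustrative only: the function names are hypothetical, and the coordinator cost assumes a simple model in which the coordinator receives N-1 full tensors and then sends N-1 full results.

```python
def ring_bytes_sent_per_worker(n_workers: int, tensor_bytes: int) -> float:
    """Bytes each worker sends in ring all-reduce: 2 * (N-1)/N * tensor_bytes."""
    # Each of the 2 * (N-1) rounds moves one chunk of size tensor_bytes / N.
    return 2 * (n_workers - 1) * tensor_bytes / n_workers


def coordinator_bytes_moved(n_workers: int, tensor_bytes: int) -> int:
    """Bytes through the coordinator in the naive scheme (assumed model).

    It receives N-1 full tensors, then sends N-1 full results.
    """
    return 2 * (n_workers - 1) * tensor_bytes


tensor_bytes = 1_000_000
for n in (2, 4, 8, 64):
    ring = ring_bytes_sent_per_worker(n, tensor_bytes)
    naive = coordinator_bytes_moved(n, tensor_bytes)
    print(f"N={n:>2}  ring per worker: {ring:>12,.0f} B   coordinator: {naive:>12,} B")
```

The ring figure levels off just under 2 * tensor_bytes, while the coordinator's load in this model grows linearly with N.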

Reference: Patarasuk & Yuan (2009), Bandwidth Optimal All-reduce Algorithms for Clusters of Workstations.

Two-phase algorithm:

Each phase runs for exactly N-1 rounds and communicates one chunk per worker per round.

Phase 1 - Scatter-Reduce (N-1 rounds):
  Chunk each worker's tensor along axis 0 into N equal pieces.
  In each round, worker i sends one chunk to its right neighbor (i+1 mod N),
  who accumulates it (adds it to the same chunk slot in its own buffer).
  After N-1 rounds, each worker holds the full sum for exactly one chunk,
  and no two workers hold the same one.
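A minimal sketch of the scatter-reduce phase alone, assuming NumPy arrays and one common send schedule (in round r, worker i sends chunk (i - r) mod N). The problem statement does not fix a schedule, so the helper name `scatter_reduce` and the index choice are assumptions made for illustration.

```python
import numpy as np


def scatter_reduce(worker_tensors: np.ndarray) -> np.ndarray:
    """Run only the scatter-reduce phase and return every worker's chunked buffer.

    Assumed schedule (one valid choice, not the only one):
    in round r, worker i sends chunk (i - r) mod N to worker (i + 1) mod N.
    """
    n = worker_tensors.shape[0]
    if worker_tensors.shape[1] % n != 0:
        raise ValueError("leading dimension of each tensor must be divisible by N")

    # View each worker's tensor as N equal chunks along its axis 0:
    # shape (N workers, N chunks, chunk_len, ...).
    chunk_len = worker_tensors.shape[1] // n
    buffers = worker_tensors.copy().reshape(n, n, chunk_len, *worker_tensors.shape[2:])

    for r in range(n - 1):
        # Snapshot outgoing chunks so all sends in a round happen "simultaneously".
        outgoing = [buffers[i, (i - r) % n].copy() for i in range(n)]
        for i in range(n):
            right = (i + 1) % n
            # The right neighbor accumulates into the same chunk slot.
            buffers[right, (i - r) % n] += outgoing[i]
    return buffers


workers = np.arange(18, dtype=np.float64).reshape(3, 6)  # 3 workers, 6 values each
buffers = scatter_reduce(workers)
total = workers.sum(axis=0).reshape(3, 2)                 # reference sum, chunked
for i in range(3):
    done = (i + 1) % 3  # under this schedule, worker i finishes chunk (i + 1) mod N
    assert np.array_equal(buffers[i, done], total[done])
```

Under this schedule worker i ends up owning chunk (i + 1) mod N; the other slots in its buffer still hold partial sums.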

Phase 2 - All-Gather (N-1 rounds):
  Each worker now passes its completed (fully summed) chunk around
  the ring; a receiver overwrites its partial copy instead of adding to it.
  After N-1 rounds, every worker holds every completed chunk, that is,
  the full sum.
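The all-gather phase can be sketched in isolation as well. This example assumes NumPy and a starting state in which worker i already owns the finished chunk (i + 1) mod N; that ownership pattern and the helper name `all_gather` are assumptions for illustration. Slots that are not yet valid are filled with NaN so it is visible that they get overwritten.

```python
import numpy as np


def all_gather(buffers: np.ndarray) -> np.ndarray:
    """Run only the all-gather phase on chunked buffers of shape (N, N, ...).

    Assumes worker i enters holding the finished chunk (i + 1) mod N.
    """
    n = buffers.shape[0]
    buffers = buffers.copy()
    for r in range(n - 1):
        # Worker i forwards the finished chunk it acquired most recently.
        outgoing = [buffers[i, (i + 1 - r) % n].copy() for i in range(n)]
        for i in range(n):
            right = (i + 1) % n
            # The receiver overwrites its stale chunk; nothing is added here.
            buffers[right, (i + 1 - r) % n] = outgoing[i]
    return buffers


n = 3
finished = np.array([[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]])  # finished[c] = chunk c
state = np.full((n, n, 2), np.nan)            # NaN marks slots that are not final yet
for i in range(n):
    state[i, (i + 1) % n] = finished[(i + 1) % n]

gathered = all_gather(state)
for i in range(n):
    assert np.array_equal(gathered[i], finished)  # every worker holds every chunk
```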

Simplification for this problem:

worker_tensors.shape[1] (the leading dimension of each worker's tensor) must be divisible by N so that chunks are equal-sized.

Input convention:

worker_tensors is a stacked tensor of shape (N, ...) where worker_tensors[i] is the tensor held by worker i.
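As a concrete illustration of this convention, the snippet below builds a stacked input and splits one worker's tensor into N chunks. It assumes NumPy arrays; the shapes and fill values are made up for the example.

```python
import numpy as np

# Hypothetical example: 4 workers, each holding a tensor of shape (8, 3).
per_worker = [np.full((8, 3), fill_value=float(i)) for i in range(4)]
worker_tensors = np.stack(per_worker)  # shape (4, 8, 3); worker_tensors[i] is worker i
n = worker_tensors.shape[0]

# Each worker's leading dimension is worker_tensors.shape[1]; it must split evenly.
assert worker_tensors.shape[1] % n == 0, "leading dimension must be divisible by N"

# Split worker 0's tensor into N equal chunks along its axis 0.
chunks = np.split(worker_tensors[0], n, axis=0)
print(worker_tensors.shape)          # (4, 8, 3)
print(len(chunks), chunks[0].shape)  # 4 (2, 3)
```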

Function signature:

def ring_all_reduce(worker_tensors):
    """Simulate ring-all-reduce sum across N workers.

    worker_tensors: shape (N, ...) - N workers, each holding a tensor of
                    identical shape. worker_tensors[i] is worker i's tensor.

    Returns: tensor of shape (...) - the SUM of all worker_tensors along
             the worker (first) dimension.

    Requirement: worker_tensors.shape[1] must be divisible by N.
    """
