hard end_to_end

Distributed Training Step End-to-End

The capstone of the Production-ML track: implement one distributed training step by composing gradient accumulation, all-reduce, Adam, and EMA, entirely by hand.

No optim.Adam. No library all-reduce. You own the entire pipeline.

Overview

In real distributed training, N workers each hold a shard of the data. Each worker:

  1. Accumulates gradients over several micro-batches (to handle memory limits).
  2. All-reduces those gradients across workers (so every worker sees the same global gradient).
  3. Applies an optimizer step (Adam with bias correction).
  4. Updates an EMA (exponential moving average) of the weights for a smoother evaluation model.

The model

A linear regressor with MSE loss. Per micro-batch (x, y) with x of shape (B, d) and y of shape (B,), the gradient of the loss with respect to the weights w is (X denotes the micro-batch matrix x):

$$\nabla_w = \frac{2}{B} X^\top (Xw - y)$$
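The gradient formula can be sanity-checked numerically. The sketch below assumes NumPy; the helper names mse_loss and mse_grad are made up for illustration and are not part of the problem. It compares the closed-form gradient against central finite differences on random data.

```python
import numpy as np

def mse_loss(w, x, y):
    """Mean squared error of the linear model x @ w against targets y."""
    return np.mean((x @ w - y) ** 2)

def mse_grad(w, x, y):
    """Closed-form gradient of mse_loss with respect to w: (2/B) X^T (Xw - y)."""
    B = x.shape[0]
    return (2.0 / B) * x.T @ (x @ w - y)

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))   # B = 8, d = 3
y = rng.normal(size=8)
w = rng.normal(size=3)

# Central finite differences along each coordinate direction.
h = 1e-6
numeric = np.array([
    (mse_loss(w + h * e, x, y) - mse_loss(w - h * e, x, y)) / (2 * h)
    for e in np.eye(3)
])
max_abs_diff = np.max(np.abs(numeric - mse_grad(w, x, y)))
```

Because the loss is quadratic in w, the central difference should agree with the closed form up to floating-point rounding, so max_abs_diff is expected to be tiny.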

Pipeline

Step 1 - Gradient accumulation (per worker):

worker_grad = zeros(d)
for k in range(accum_steps):
    grad_k = (2/B) * x[k].T @ (x[k] @ w - y[k])
    worker_grad += grad_k
worker_grad /= accum_steps
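A runnable version of the pseudocode above, as a minimal sketch assuming NumPy. The function name accumulate_worker_grad is hypothetical; xs and ys stand for one worker's slice of the inputs, with shapes (accum_steps, B, d) and (accum_steps, B).

```python
import numpy as np

def accumulate_worker_grad(w, xs, ys):
    """Average the MSE gradients of one worker's micro-batches.

    xs: (accum_steps, B, d), ys: (accum_steps, B), w: (d,)
    """
    accum_steps, B, d = xs.shape
    worker_grad = np.zeros(d)
    for k in range(accum_steps):
        residual = xs[k] @ w - ys[k]                    # shape (B,)
        worker_grad += (2.0 / B) * xs[k].T @ residual   # shape (d,)
    return worker_grad / accum_steps
```

Since every micro-batch has the same size B, this average of per-micro-batch gradients equals the gradient computed on one large batch of size accum_steps * B, which is why accumulation can stand in for a batch that does not fit in memory.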

Step 2 - All-reduce (across workers):

averaged_grad = sum(worker_grads) / N_workers

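This simulates an all-reduce with mean semantics: a sum across workers followed by division by N_workers, which is how data-parallel training typically averages gradients with NCCL AllReduce.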
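The simulated all-reduce could look like the sketch below, assuming NumPy. The function name simulated_all_reduce_mean is made up for illustration; no real communication library is involved. It returns one copy of the averaged gradient per worker to mirror the fact that every worker ends up with the same result.

```python
import numpy as np

def simulated_all_reduce_mean(worker_grads):
    """worker_grads: (N_workers, d). Returns N_workers identical averaged copies."""
    worker_grads = np.asarray(worker_grads, dtype=float)
    n_workers = worker_grads.shape[0]
    total = worker_grads.sum(axis=0)   # the "reduce": sum over workers
    averaged = total / n_workers       # mean semantics
    # The "all" part: every worker receives the same reduced value.
    return [averaged.copy() for _ in range(n_workers)]

# Two workers with gradients [1, 2] and [3, 6] both receive [2, 4].
per_worker = simulated_all_reduce_mean([[1.0, 2.0], [3.0, 6.0]])
```

In this problem only one copy is needed, since a single set of weights is updated afterwards.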

Step 3 - Adam update (with bias correction):

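$$m = \beta_1 m + (1 - \beta_1) g$$

$$v = \beta_2 v + (1 - \beta_2) g^2$$

$$\hat{m} = m / (1 - \beta_1^t), \quad \hat{v} = v / (1 - \beta_2^t)$$

$$w = w - \eta\, \hat{m} / (\sqrt{\hat{v}} + \epsilon)$$

Here g is averaged_grad from Step 2, η is lr, and the square, square root, and division are element-wise.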
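The Adam equations translate almost line by line into NumPy. This is a minimal sketch; the function name adam_step and its argument order are assumptions, not a prescribed interface.

```python
import numpy as np

def adam_step(w, g, m, v, lr, beta1, beta2, eps, t):
    """One Adam update with bias correction; t is the 1-indexed step."""
    m = beta1 * m + (1.0 - beta1) * g        # first moment
    v = beta2 * v + (1.0 - beta2) * g * g    # second moment (element-wise square)
    m_hat = m / (1.0 - beta1 ** t)           # bias-corrected first moment
    v_hat = v / (1.0 - beta2 ** t)           # bias-corrected second moment
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

# First step from zero moments: the bias correction gives m_hat = g and
# v_hat = g**2, so each coordinate moves by roughly lr against the sign of g.
w1, m1, v1 = adam_step(
    w=np.zeros(2), g=np.array([0.5, -2.0]),
    m=np.zeros(2), v=np.zeros(2),
    lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, t=1,
)
```

The example illustrates why t must start at 1: with t = 0 both denominators would be zero.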

Step 4 - EMA update:

$$w_{\text{ema}} = \delta\, w_{\text{ema}} + (1 - \delta)\, w$$
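As a small sketch of the EMA formula, assuming NumPy: the function name ema_update is hypothetical, and the w passed in is taken to be the weights after the Adam update, following the order of the pipeline steps.

```python
import numpy as np

def ema_update(ema_w, w, decay):
    """Blend the previous EMA weights with the current weights."""
    return decay * ema_w + (1.0 - decay) * w

# With decay = 0.9 the EMA moves 10% of the way toward the current weights:
# [1, 1] blended with [0, 2] gives approximately [0.9, 1.1].
new_ema = ema_update(np.array([1.0, 1.0]), np.array([0.0, 2.0]), decay=0.9)
```

A decay close to 1 makes the EMA weights change slowly, which is what produces the smoother evaluation model.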

Inputs

  • worker_xs: shape (N_workers, accum_steps, B, d).
  • worker_ys: shape (N_workers, accum_steps, B).
  • weights: shape (d,), current model parameters.
  • ema_weights: shape (d,), current EMA weights.
  • m, v: shape (d,), Adam first/second moments.
  • lr, beta1, beta2, eps: Adam hyperparameters.
  • decay: EMA decay coefficient (δ in Step 4).
  • t: int, current step (1-indexed for bias correction).

Output

Returns an array of shape (4*d,): the concatenation of (new_weights, new_ema_weights, new_m, new_v), flattened.
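Putting the four steps together, one possible shape for the whole function is sketched below, assuming NumPy. The name distributed_training_step and the argument order (taken from the Inputs list) are assumptions; the problem does not state the expected signature. The EMA is computed from the weights after the Adam update, following the order of the pipeline steps.

```python
import numpy as np

def distributed_training_step(worker_xs, worker_ys, weights, ema_weights,
                              m, v, lr, beta1, beta2, eps, decay, t):
    worker_xs = np.asarray(worker_xs, dtype=float)
    worker_ys = np.asarray(worker_ys, dtype=float)
    w = np.asarray(weights, dtype=float)
    ema_w = np.asarray(ema_weights, dtype=float)
    m = np.asarray(m, dtype=float)
    v = np.asarray(v, dtype=float)
    n_workers, accum_steps, B, d = worker_xs.shape

    # Step 1: each worker averages its micro-batch gradients.
    worker_grads = np.zeros((n_workers, d))
    for i in range(n_workers):
        for k in range(accum_steps):
            x, y = worker_xs[i, k], worker_ys[i, k]
            worker_grads[i] += (2.0 / B) * x.T @ (x @ w - y)
        worker_grads[i] /= accum_steps

    # Step 2: simulated all-reduce with mean semantics.
    g = worker_grads.sum(axis=0) / n_workers

    # Step 3: Adam update with bias correction (t is 1-indexed).
    new_m = beta1 * m + (1.0 - beta1) * g
    new_v = beta2 * v + (1.0 - beta2) * g * g
    m_hat = new_m / (1.0 - beta1 ** t)
    v_hat = new_v / (1.0 - beta2 ** t)
    new_w = w - lr * m_hat / (np.sqrt(v_hat) + eps)

    # Step 4: EMA of the freshly updated weights.
    new_ema = decay * ema_w + (1.0 - decay) * new_w

    return np.concatenate([new_w, new_ema, new_m, new_v])
```

The result can be split back into its four parts with np.split(out, 4). Because all micro-batches and workers hold the same number of samples, the averaged gradient g matches the gradient on the full pooled dataset, which gives a convenient way to test the first two steps.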

Note

This is a simulated distributed system: worker_xs[i] represents what worker i has. Real distributed training uses NCCL/MPI for the communication; this problem is the algorithmic core.

Hints

training distributed ema gradient-accumulation
