Distributed Training Step End-to-End
The capstone of the Production-ML track: implement one distributed training step by composing gradient accumulation, all-reduce, Adam, and EMA, entirely by hand.
No optim.Adam. No library all-reduce. You own the entire pipeline.
Overview
In real distributed training, N workers each hold a shard of the data. Each worker:
- Accumulates gradients over several micro-batches (to handle memory limits).
- All-reduces those gradients across workers (so every worker sees the same global gradient).
- Applies an optimizer step (Adam with bias correction).
- Updates an EMA (exponential moving average) of the weights for a smoother evaluation model.
The model
A linear regressor with MSE loss. For a micro-batch (x, y) with x of
shape (B, d) and y of shape (B,), the loss L is the mean of the squared
residuals (x w - y), and its gradient with respect to w is:
$$\nabla_w L = \frac{2}{B} x^\top (x w - y)$$
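The gradient formula can be checked numerically. The sketch below is illustrative only: `mse_grad` and `numerical_grad` are hypothetical helper names, not part of the problem's interface, and the random data exists purely for the check.

```python
import numpy as np

def mse_grad(w, x, y):
    """Gradient of mean((x @ w - y)**2) with respect to w."""
    B = x.shape[0]
    residual = x @ w - y                 # shape (B,)
    return (2.0 / B) * x.T @ residual    # shape (d,)

def numerical_grad(w, x, y, h=1e-6):
    """Central finite differences, used only to sanity-check mse_grad."""
    loss = lambda w_: np.mean((x @ w_ - y) ** 2)
    g = np.zeros_like(w)
    for i in range(w.size):
        e = np.zeros_like(w)
        e[i] = h
        g[i] = (loss(w + e) - loss(w - e)) / (2 * h)
    return g

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))
y = rng.normal(size=8)
w = rng.normal(size=3)
assert np.allclose(mse_grad(w, x, y), numerical_grad(w, x, y), atol=1e-5)
```

The loss is quadratic in w, so central differences agree with the closed form up to floating-point rounding.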
Pipeline
Step 1: Gradient accumulation (per worker):
    worker_grad = zeros(d)
    for k in range(accum_steps):
        grad_k = (2/B) * x[k].T @ (x[k] @ w - y[k])
        worker_grad += grad_k
    worker_grad /= accum_steps
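A runnable NumPy version of the accumulation loop might look like the following. The function name `accumulate_worker_grad` is an assumption made for illustration; the shapes follow the problem's input description for a single worker.

```python
import numpy as np

def accumulate_worker_grad(w, xs, ys):
    """Average the MSE gradient over one worker's micro-batches.

    xs: (accum_steps, B, d), ys: (accum_steps, B), w: (d,)
    """
    accum_steps, B, _ = xs.shape
    worker_grad = np.zeros_like(w, dtype=float)
    for k in range(accum_steps):
        # Per-micro-batch gradient, added into the running sum.
        worker_grad += (2.0 / B) * xs[k].T @ (xs[k] @ w - ys[k])
    # Divide once at the end so the result is a mean, not a sum.
    return worker_grad / accum_steps
```

Because every micro-batch has the same size B, this average equals the gradient computed on one large batch of accum_steps * B examples, which is why accumulation can stand in for a batch that does not fit in memory.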
Step 2: All-reduce (across workers):
    averaged_grad = sum(worker_grads) / N_workers
This simulates the mean-reduce semantics of NCCL AllReduce.
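In a single process, the all-reduce can be simulated as a sum over the worker axis followed by a division. The sketch below uses a hypothetical helper, `simulated_all_reduce_mean`, and returns one copy of the result per worker to mirror the fact that every rank ends up holding the same vector.

```python
import numpy as np

def simulated_all_reduce_mean(worker_grads):
    """Simulate a mean all-reduce over per-worker gradients.

    worker_grads: (N_workers, d). Returns an array of the same shape in
    which every row is the average over workers.
    """
    worker_grads = np.asarray(worker_grads, dtype=float)
    n_workers = worker_grads.shape[0]
    mean = worker_grads.sum(axis=0) / n_workers
    # Each simulated worker receives its own copy of the reduced vector.
    return np.broadcast_to(mean, worker_grads.shape).copy()

reduced = simulated_all_reduce_mean([[1.0, 2.0], [3.0, 4.0]])
assert np.allclose(reduced, [[2.0, 3.0], [2.0, 3.0]])
```

For this problem only one row is needed, since the update is computed once; the per-worker copies are kept here only to make the "everyone sees the same gradient" property explicit.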
Step 3: Adam update (with bias correction):
$$m = \beta_1 m + (1 - \beta_1) g$$
$$v = \beta_2 v + (1 - \beta_2) g^2$$
$$\hat{m} = m / (1 - \beta_1^t), \quad \hat{v} = v / (1 - \beta_2^t)$$
$$w = w - \eta\, \hat{m} / (\sqrt{\hat{v}} + \epsilon)$$
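The four Adam equations translate almost line for line into NumPy. In this sketch, `adam_step` is a hypothetical name, g stands for the averaged gradient from Step 2, η corresponds to `lr`, and the default hyperparameter values are commonly used ones assumed for illustration, not values given by the problem.

```python
import numpy as np

def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update with bias correction; t is 1-indexed."""
    m = beta1 * m + (1.0 - beta1) * g        # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * g * g    # second-moment estimate
    m_hat = m / (1.0 - beta1 ** t)           # bias-corrected moments
    v_hat = v / (1.0 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

# At t = 1 with zero-initialised moments, m_hat = g and v_hat = g**2,
# so each coordinate moves by roughly lr in the direction of -sign(g).
w, m, v = adam_step(np.zeros(2), np.array([0.5, -2.0]),
                    np.zeros(2), np.zeros(2), t=1, lr=0.1)
assert np.allclose(w, [-0.1, 0.1], atol=1e-6)
```

Note that the stored moments `m` and `v` are the uncorrected ones; the bias-corrected versions are used only inside the weight update.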
Step 4: EMA update:
$$w_{\text{ema}} = \delta\, w_{\text{ema}} + (1 - \delta)\, w$$
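As a one-line sketch, with δ mapped to the `decay` input and `ema_update` as a hypothetical helper name. It assumes w here is the weight vector produced by the Adam update in Step 3, which is what the step ordering suggests.

```python
import numpy as np

def ema_update(ema_w, w, decay):
    """Blend the previous EMA weights with the current weights."""
    return decay * ema_w + (1.0 - decay) * w

ema = ema_update(np.array([1.0, 1.0]), np.array([0.0, 2.0]), decay=0.9)
assert np.allclose(ema, [0.9, 1.1])
```

A decay close to 1 makes the EMA weights change slowly, which is what gives the smoother evaluation model mentioned in the overview.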
Inputs
- worker_xs: shape (N_workers, accum_steps, B, d).
- worker_ys: shape (N_workers, accum_steps, B).
- weights: shape (d,), current model parameters.
- ema_weights: shape (d,), current EMA weights.
- m, v: shape (d,), Adam first and second moments.
- lr, beta1, beta2, eps: Adam hyperparameters.
- decay: EMA decay coefficient.
- t: int, current step (1-indexed for bias correction).
Output
Returns an array of shape (4*d,): the flat concatenation of
(new_weights, new_ema_weights, new_m, new_v), in that order.
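Putting the four steps together, one possible end-to-end sketch is shown below. The function name `distributed_train_step` and the argument order are assumptions, since the problem statement does not show the required signature, and the EMA is assumed to use the post-Adam weights.

```python
import numpy as np

def distributed_train_step(worker_xs, worker_ys, weights, ema_weights,
                           m, v, lr, beta1, beta2, eps, decay, t):
    n_workers, accum_steps, B, d = worker_xs.shape

    # Step 1: each worker averages gradients over its micro-batches.
    worker_grads = np.zeros((n_workers, d))
    for i in range(n_workers):
        for k in range(accum_steps):
            x, y = worker_xs[i, k], worker_ys[i, k]
            worker_grads[i] += (2.0 / B) * x.T @ (x @ weights - y)
        worker_grads[i] /= accum_steps

    # Step 2: simulated mean all-reduce across workers.
    g = worker_grads.sum(axis=0) / n_workers

    # Step 3: Adam with bias correction (t is 1-indexed).
    new_m = beta1 * m + (1.0 - beta1) * g
    new_v = beta2 * v + (1.0 - beta2) * g * g
    m_hat = new_m / (1.0 - beta1 ** t)
    v_hat = new_v / (1.0 - beta2 ** t)
    new_w = weights - lr * m_hat / (np.sqrt(v_hat) + eps)

    # Step 4: EMA of the updated weights.
    new_ema = decay * ema_weights + (1.0 - decay) * new_w

    return np.concatenate([new_w, new_ema, new_m, new_v])

# Usage on random data; the sizes and hyperparameters are arbitrary.
rng = np.random.default_rng(0)
N, K, B, d = 2, 3, 4, 5
out = distributed_train_step(
    rng.normal(size=(N, K, B, d)), rng.normal(size=(N, K, B)),
    np.zeros(d), np.zeros(d), np.zeros(d), np.zeros(d),
    lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, decay=0.99, t=1)
new_w, new_ema, new_m, new_v = np.split(out, 4)
```

`np.split(out, 4)` recovers the four length-d pieces in the order in which they were concatenated.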
Note
This is a simulated distributed system: worker_xs[i] holds the data
that worker i sees. Real distributed training uses NCCL or MPI for
communication; this problem covers only the algorithmic core.