
4-Step Training Loop with Scan + Loss Curve

Why this matters

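A common pattern for JAX training loops is to express the entire multi-step iteration as a single lax.scan. scan traces the step function once and lowers the whole loop to one compiled XLA loop, so there is no Python overhead per step, the body can be fused by XLA, and the result composes with jax.jit. The carry threads the state that changes from step to step (the weight and the optimizer state). The xs supply one batch per step. The ys collect one loss per step, and scan stacks them into the loss curve.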
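To separate the carry/xs/ys contract from the optimizer details, here is a minimal sketch with no training at all. It uses lax.scan to keep a running total: the carry is the total so far, each xs slice is one scalar, and ys records the total *before* each addition, which mirrors how the recipe records the pre-update loss. The function name `running_totals` is hypothetical.

```python
import jax.numpy as jnp
from jax import lax

def running_totals(xs):
    """Scan over xs, returning (final_total, total_before_each_step)."""
    def step(total, x):
        before = total            # value recorded before this step's update
        total = total + x         # new carry handed to the next step
        return total, before      # (new carry, per-step output stacked by scan)
    return lax.scan(step, jnp.float32(0.0), xs)

final, history = running_totals(jnp.array([1.0, 2.0, 3.0, 4.0]))
# final == 10.0; history == [0.0, 1.0, 3.0, 6.0]
```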

The recipe

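import jax
import jax.numpy as jnp
from jax import lax
import optax

def loss_fn(w, x, y):
    # Sum of squared errors for the scalar linear model y ≈ w * x
    return jnp.sum((w * x - y) ** 2)

# Loss and dL/dw from a single forward/backward pass
loss_and_grad = jax.value_and_grad(loss_fn, argnums=0)

def train(w0, x_batches, y_batches, lr):  # illustrative wrapper name
    w0 = jnp.asarray(w0, dtype=jnp.float32)   # keep the carry dtype fixed across steps
    optimizer = optax.sgd(lr)
    opt_state = optimizer.init(w0)

    def step(carry, batch):
        w, opt_state = carry
        x, y = batch                           # one row of each (4, B) input
        loss, g = loss_and_grad(w, x, y)       # pre-update loss
        updates, opt_state = optimizer.update(g, opt_state, w)
        new_w = optax.apply_updates(w, updates)
        return (new_w, opt_state), loss        # new carry, stacked output

    _, losses = lax.scan(step, (w0, opt_state), (x_batches, y_batches))
    return losses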

Common pitfalls

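  • opt_state must travel in the carry. If the step function closes over it instead, every step reuses the initial optimizer state, so stateful optimizers (SGD with momentum, Adam) compute wrong updates and the loss curve comes out wrong. Plain optax.sgd without momentum keeps no array state, so this bug can stay hidden until the optimizer is swapped.
  • Record the loss before applying the update. The expected output is the pre-update loss; computing it from new_w reports the loss of the updated weight instead.
  • lax.scan slices the xs along axis 0, so x_batches and y_batches must both have the step count as their leading dimension. With shape (4, B), scan runs 4 steps and each step receives x and y of shape (B,).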

Inputs

  • w0: scalar, the initial weight.
  • x_batches: 2-D JAX array of shape (4, B), the per-step inputs.
  • y_batches: 2-D JAX array of shape (4, B), the per-step targets.
  • lr: scalar, the SGD learning rate.
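If the per-step batches start out as a Python list of equal-length 1-D arrays, jnp.stack along a new leading axis is one way to build the (4, B) inputs. The sketch below uses made-up data with B = 3; the names `x_list` and `y_list` are hypothetical. A trivial scan then confirms that each step receives one row of each array.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical per-step batches: 4 steps, B = 3 examples each
x_list = [jnp.arange(3.0) + i for i in range(4)]
y_list = [2.0 * x for x in x_list]

x_batches = jnp.stack(x_list, axis=0)   # shape (4, 3): axis 0 indexes the step
y_batches = jnp.stack(y_list, axis=0)   # shape (4, 3)

# scan hands the step function one row per iteration; here it just returns row sums
def step(carry, batch):
    x, y = batch                          # each of shape (3,)
    return carry, (jnp.sum(x), jnp.sum(y))

_, (x_sums, y_sums) = lax.scan(step, None, (x_batches, y_batches))
```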

Output

1-D array of shape (4,) holding the pre-update loss at each of the 4 steps.
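The scan computes the recurrence below for the scalar weight, with η standing for lr. The numeric instance on the last line uses illustrative values that are not part of the problem statement: w_0 = 0, a single batch x = (1, 2), y = (2, 4) reused at every step, and η = 0.1.

```latex
L_t(w) = \sum_{i=1}^{B} \left(w\,x_{t,i} - y_{t,i}\right)^2,
\qquad
g_t = \frac{\partial L_t}{\partial w}\Big|_{w = w_t} = 2\sum_{i=1}^{B} \left(w_t\,x_{t,i} - y_{t,i}\right) x_{t,i},
\qquad
w_{t+1} = w_t - \eta\, g_t

\text{output} = \big(L_0(w_0),\ L_1(w_1),\ L_2(w_2),\ L_3(w_3)\big)

\text{Example: } L_0(0) = 2^2 + 4^2 = 20,\quad
g_0 = 2\big[(-2)(1) + (-4)(2)\big] = -20,\quad
w_1 = 0 - 0.1\cdot(-20) = 2,\quad
L_1(2) = 0
```

Because g_1 = 0, the weight stays at 2, and the output for this instance is (20, 0, 0, 0).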

Hints

optax training scan loss-curve
