
Training Loop via lax.scan

Why this matters

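In production JAX code you rarely write a Python for loop over training steps; you scan them instead. lax.scan traces the step function once and lowers the whole loop to a single XLA loop, so under jax.jit the full training run becomes one compiled program: one JIT boundary, one dispatch to the accelerator, and no per-step Python overhead. A Python for loop inside jax.jit, by contrast, is unrolled N times, which inflates trace and compile time.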
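A minimal sketch contrasting the two approaches, assuming the same toy loss sum((w*x - y)^2) used in the worked example. The helper names sgd_step, train_python_loop and train_scan are hypothetical; both versions should reach the same final weight.

```python
import jax
import jax.numpy as jnp
from jax import lax

def sgd_step(w, batch, lr=0.1):
    # One gradient-descent step on sum((w*x - y)^2) for a single microbatch
    x, y = batch
    g = jax.grad(lambda w: jnp.sum((w * x - y) ** 2))(w)
    return w - lr * g

def train_python_loop(w, xs, ys):
    # Baseline: Python drives the loop and dispatches every step separately
    # (wrapped in jax.jit, this loop would be unrolled N times)
    for i in range(xs.shape[0]):
        w = sgd_step(w, (xs[i], ys[i]))
    return w

@jax.jit
def train_scan(w, xs, ys):
    # lax.scan traces the body once and emits a single loop in the compiled program
    def body(w, batch):
        return sgd_step(w, batch), None
    w_final, _ = lax.scan(body, w, (xs, ys))
    return w_final

xs = jnp.ones((5, 3))   # 5 microbatches of 3 examples
ys = 2.0 * xs           # generated with true weight 2.0
w0 = jnp.float32(0.0)
print(train_python_loop(w0, xs, ys), train_scan(w0, xs, ys))
```

The results match; the difference is in how the work is compiled and dispatched, which matters more as the step body and the number of steps grow.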

The canonical pattern: the carry is the model weight (or the full parameter pytree), and xs holds the microbatches stacked along axis 0. The body computes the gradient for one microbatch and returns the pair (updated_weight, per_step_output). The same pattern extends to richer carries, such as a Flax TrainState or a (params, opt_state) pair updated with an Optax optimizer.
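A sketch of the pytree-carry variant, assuming a hypothetical two-parameter linear model (a dict with keys "w" and "b") and plain SGD applied leaf by leaf with jax.tree_util.tree_map. An Optax-based version would carry (params, opt_state) instead; that variant is not shown here.

```python
import jax
import jax.numpy as jnp
from jax import lax

def loss(params, x, y):
    # Hypothetical linear model: pred = w * x + b, mean squared error
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

def make_step(lr):
    def step(params, batch):
        x, y = batch
        grads = jax.grad(loss)(params, x, y)  # same pytree structure as params
        new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        return new_params, None               # carry is the whole parameter dict
    return step

params0 = {"w": jnp.float32(0.0), "b": jnp.float32(0.0)}
x_batches = jnp.tile(jnp.linspace(-1.0, 1.0, 8), (200, 1))  # shape (N=200, B=8)
y_batches = 3.0 * x_batches + 1.0                            # true w = 3, b = 1

params, _ = lax.scan(make_step(0.1), params0, (x_batches, y_batches))
print(params)  # w and b move toward 3.0 and 1.0
```

Because scan only requires the carry to keep the same pytree structure, shapes and dtypes on every step, swapping a scalar weight for a dict of arrays needs no other changes.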

Worked mini-example

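import jax
import jax.numpy as jnp
from jax import lax

lr = 0.1  # learning rate; step() closes over it

def loss(w, x, y):
    # Squared error summed over one microbatch
    return jnp.sum((w * x - y) ** 2)

grad_fn = jax.grad(loss)  # gradient with respect to the first argument, w

def step(w, batch):
    x, y = batch                # one microbatch: (x_batches[i], y_batches[i])
    g = grad_fn(w, x, y)
    return w - lr * g, None     # (new_carry, unused_output)

# Illustrative data: N=4 microbatches of B=4 examples, generated with w = 2
x_batches = jnp.ones((4, 4))
y_batches = 2.0 * x_batches

w_final, _ = lax.scan(step, jnp.float32(0.0), (x_batches, y_batches))
# w_final ends close to 2.0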

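Key: pass (x_batches, y_batches) as xs. lax.scan slices every leaf of the xs pytree along its leading axis, so on step i the body receives batch = (x_batches[i], y_batches[i]). All leaves must therefore share the same leading dimension N.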
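A small sketch that makes the slicing visible, using a hypothetical body named inspect that counts steps in the carry and emits per-row sums as its per-step output. Per-step outputs are stacked along a new leading axis of length N.

```python
import jax.numpy as jnp
from jax import lax

x_batches = jnp.arange(12.0).reshape(3, 4)   # N=3 microbatches, B=4
y_batches = 10.0 * x_batches

def inspect(carry, batch):
    x, y = batch               # each has shape (4,): one row of the stacked arrays
    out = (x.sum(), y.sum())   # per-step outputs, stacked by scan into shape (N,)
    return carry + 1, out      # carry counts how many steps ran

count, (x_sums, y_sums) = lax.scan(inspect, 0, (x_batches, y_batches))
print(count, x_sums, y_sums)
```

Here count is 3, x_sums holds the sums of rows 0, 1 and 2 of x_batches, and y_sums holds the matching sums of y_batches.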

Common pitfalls

  • Forgetting to unpack the batch. Inside the body, batch is the tuple (x, y). Write x, y = batch before using them; otherwise you pass a tuple where a JAX array is expected and the loss computation fails.
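  • Worrying about how lr is captured. lr is a scalar defined outside the body, and the body simply closes over it, so every step sees the same value. This is fine; there is no need to thread it through the carry or xs. Only a per-step learning-rate schedule needs to be passed in through xs.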
  • Wrong per-step output. When you don’t need per-step outputs, return None as the second element. lax.scan still needs the 2-tuple (new_carry, output).
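When you do want per-step outputs, return them as the second element and scan stacks them. The sketch below records the loss at every step with jax.value_and_grad and, to illustrate the schedule case from the list above, slices a per-step learning rate from xs. The halving schedule lrs is a hypothetical choice.

```python
import jax
import jax.numpy as jnp
from jax import lax

def loss(w, x, y):
    return jnp.sum((w * x - y) ** 2)

def step(w, inputs):
    x, y, lr_i = inputs                       # lr_i: this step's learning rate, sliced from xs
    l, g = jax.value_and_grad(loss)(w, x, y)  # loss before the update, and its gradient
    return w - lr_i * g, l                    # emit the loss as the per-step output

N, B = 6, 4
x_batches = jnp.ones((N, B))
y_batches = 2.0 * x_batches
lrs = 0.1 * 0.5 ** jnp.arange(N)              # hypothetical decaying schedule, shape (N,)

w_final, losses = lax.scan(step, jnp.float32(0.0), (x_batches, y_batches, lrs))
print(w_final, losses)                        # losses has shape (N,)
```

The schedule is just one more leaf in the xs tuple, so it must have the same leading dimension N as the microbatches.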

Problem

Implement scan_training_step(w, lr, x_batches, y_batches) using lax.scan. Minimise the loss sum((w*x - y)^2) by taking one gradient-descent step with learning rate lr on each of the N microbatches. Return the final weight as a scalar.

  • w: scalar.
  • lr: scalar.
  • x_batches: JAX array of shape (N, B).
  • y_batches: JAX array of shape (N, B).
  • Returns: scalar — final w after N gradient-descent updates.
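One possible sketch of the requested function, assuming float inputs; this is not the page's reference solution. The initial weight is cast to the data dtype so the carry keeps a fixed dtype across steps, and the whole function is wrapped in jax.jit so the scan compiles into one program.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def scan_training_step(w, lr, x_batches, y_batches):
    def loss(w, x, y):
        # Loss for one microbatch: sum((w*x - y)^2)
        return jnp.sum((w * x - y) ** 2)

    def step(w, batch):
        x, y = batch                  # batch = (x_batches[i], y_batches[i])
        g = jax.grad(loss)(w, x, y)
        return w - lr * g, None       # lr is closed over; no per-step output needed

    # Fix the carry dtype up front so every step returns the same type
    w = jnp.asarray(w, dtype=x_batches.dtype)
    w_final, _ = lax.scan(step, w, (x_batches, y_batches))
    return w_final

w_out = scan_training_step(0.0, 0.1, jnp.ones((4, 4)), 2.0 * jnp.ones((4, 4)))
print(w_out)
```

A plain Python loop applying the same update N times is a convenient reference to check the scanned result against.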
