Training Loop via lax.scan
Why this matters
In production JAX code you rarely write a Python for loop over
training steps; you scan them. Under jit, lax.scan compiles the entire
training loop into a single XLA program: one JIT boundary, one call to the
accelerator, no per-step Python overhead.
The canonical pattern is: the carry is the model weight (or the full
parameter pytree), and the xs are the microbatches stacked along axis 0.
The body computes a gradient and returns the updated weight. The same
carry-in, carry-out pattern extends to richer state, such as Flax's
train_state.TrainState or the optimizer state of an optax.chain.
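To make the "full parameter pytree" case concrete, here is a minimal sketch in which the carry is a dict holding a weight and a bias. The linear model, the mean-squared-error loss, the step size 0.1, and the synthetic data are all assumptions chosen for illustration; none of them come from the text above.

```python
import jax
import jax.numpy as jnp
from jax import lax

def loss(params, x, y):
    # params is a dict (a pytree) with a scalar weight "w" and a scalar bias "b".
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

def step(params, batch):
    x, y = batch
    grads = jax.grad(loss)(params, x, y)  # same dict structure as params
    # Apply one plain gradient-descent update to every leaf of the pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, None

# Illustrative data: 100 identical microbatches of 4 points on the line y = 2x + 1.
x_batches = jnp.tile(jnp.array([[-1.0, 0.0, 1.0, 2.0]]), (100, 1))
y_batches = 2.0 * x_batches + 1.0

init = {"w": jnp.zeros(()), "b": jnp.zeros(())}
params_final, _ = lax.scan(step, init, (x_batches, y_batches))
```

The carry returned by the body must have the same structure, shapes, and dtypes as the carry it received, which is why the update is applied leaf by leaf with tree_map rather than by rebuilding the dict by hand.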
Worked mini-example
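import jax, jax.numpy as jnp
from jax import lax

def loss(w, x, y):
    return jnp.sum((w * x - y) ** 2)

grad_fn = jax.grad(loss)  # gradient with respect to w (the first argument)

def step(w, batch):
    x, y = batch
    g = grad_fn(w, x, y)
    return w - 0.1 * g, None  # (new_carry, unused_output)

# Illustrative data (an assumption, not part of the original): 8 microbatches of size 4.
x_batches = jnp.ones((8, 4))
y_batches = 2.0 * x_batches

w_final, _ = lax.scan(step, 0.0, (x_batches, y_batches))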
Key: pass (x_batches, y_batches) as xs. lax.scan slices every array in
the tuple along axis 0, so on step i the body receives
batch = (x_batches[i], y_batches[i]).
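The slicing behaviour can be checked on a tiny input. The sketch below uses made-up arrays and a body that only counts steps and returns x + y, so the stacked output shows exactly which rows each step saw. The names body, n_steps, and stacked are illustrative.

```python
import jax.numpy as jnp
from jax import lax

x_batches = jnp.arange(6.0).reshape(3, 2)          # shape (N=3, B=2)
y_batches = 10.0 * jnp.arange(6.0).reshape(3, 2)   # shape (3, 2)

def body(carry, batch):
    x, y = batch              # each has shape (2,): row i of x_batches and y_batches
    return carry + 1, x + y   # the second element is stacked across steps

n_steps, stacked = lax.scan(body, 0, (x_batches, y_batches))
```

Here n_steps counts the iterations (one per row along axis 0), and stacked has shape (3, 2), matching x_batches + y_batches row by row.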
Common pitfalls
- Forgetting to unpack the batch. Inside the body, batch is a tuple (x, y). You must write x, y = batch before using them; missing this means you pass a tuple where a JAX array is expected.
- lr captured wrong. lr is a scalar that lives outside the body; JAX captures it as a constant. This is fine: no need to pass it as part of the carry or xs.
- Wrong per-step output. When you don't need per-step outputs, return None as the second element. lax.scan still needs the 2-tuple (new_carry, output).
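The second and third points can be seen together in one sketch: lr is closed over by the body rather than threaded through the carry, and the second element of the return value is used to record the loss at each step instead of None. The function name train_with_history and the data are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

def loss(w, x, y):
    return jnp.sum((w * x - y) ** 2)

def train_with_history(w0, lr, x_batches, y_batches):
    def step(w, batch):
        x, y = batch
        value, g = jax.value_and_grad(loss)(w, x, y)
        # lr comes from the enclosing scope; it is neither carry nor xs.
        return w - lr * g, value  # per-step output: the loss before the update
    return lax.scan(step, w0, (x_batches, y_batches))

# Illustrative data: 20 microbatches of size 4, with the optimum at w = 3.
x_batches = jnp.ones((20, 4))
y_batches = 3.0 * jnp.ones((20, 4))

w_final, losses = train_with_history(jnp.zeros(()), 0.05, x_batches, y_batches)
```

The per-step outputs are stacked along a new leading axis, so losses has one entry per microbatch. Returning None instead simply skips that stacking.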
Problem
Implement scan_training_step(w, lr, x_batches, y_batches) using
lax.scan. Minimise the loss sum((w*x - y)^2) over N microbatches
with learning rate lr. Return the scalar final weight.
- w: scalar.
- lr: scalar.
- x_batches: 2-D JAX array (N, B).
- y_batches: 2-D JAX array (N, B).
- Returns: scalar, the final w after N gradient-descent updates.
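One way to check a scan-based implementation is to compare it against a plain Python loop that applies the same update one microbatch at a time. The sketch below is such a reference, not the scan solution itself. The name reference_training_step and the test data are hypothetical.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.sum((w * x - y) ** 2)

def reference_training_step(w, lr, x_batches, y_batches):
    # Hypothetical helper: an unrolled Python loop, slow but easy to read.
    grad_fn = jax.grad(loss)
    w = jnp.asarray(w, dtype=jnp.float32)
    for x, y in zip(x_batches, y_batches):  # iterates over axis 0, one microbatch at a time
        w = w - lr * grad_fn(w, x, y)
    return w

# Illustrative data: N = 3 microbatches of size B = 2, with y = 2x.
x_batches = jnp.array([[1.0, 2.0], [0.5, 1.0], [1.5, 0.5]])
y_batches = 2.0 * x_batches

w_ref = reference_training_step(0.0, 0.05, x_batches, y_batches)
```

A scan_training_step implementation called with the same arguments should agree with w_ref up to floating-point tolerance, for example when compared with jnp.allclose.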