hard
primitives
4-Step Training Loop with Scan + Loss Curve
Why this matters
A production JAX training loop can express the entire multi-step iteration
as a single lax.scan. The whole loop is then jit-compatible, and XLA compiles
it once as one program instead of dispatching each step from Python. The carry
holds the state that evolves from step to step (the weight and the optimizer
state). The xs are the per-step batches. The ys are the per-step losses, which
scan stacks into an array that serves as the loss curve.
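To make the carry/xs/ys roles concrete, here is a minimal sketch that has nothing to do with training: a running sum written with lax.scan and wrapped in jax.jit. The function name running_sum is illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def running_sum(xs):
    def step(total, x):
        # carry in: the running total so far; x: one element of xs (sliced along axis 0)
        new_total = total + x
        return new_total, new_total  # (new carry, per-step y)

    # lax.scan returns (final carry, ys stacked along a new axis 0)
    return lax.scan(step, jnp.float32(0.0), xs)

final, history = running_sum(jnp.arange(1.0, 5.0))
# final → 10.0, history → [1., 3., 6., 10.]
```

The training loop below has the same shape. Only the carry becomes a (weight, optimizer state) pair and the per-step y becomes the loss.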
The recipe
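import jax
import jax.numpy as jnp
from jax import lax
import optax

def train(w0, x_batches, y_batches, lr):  # function name is illustrative
    optimizer = optax.sgd(lr)
    opt_state = optimizer.init(w0)

    def loss_fn(w, x, y):
        # Sum of squared errors for the scalar model y ≈ w * x
        return jnp.sum((w * x - y) ** 2)

    grad_fn = jax.grad(loss_fn, argnums=0)

    def step(carry, batch):
        w, opt_state = carry  # carry: (weight, optimizer state)
        x, y = batch          # one row of x_batches and of y_batches
        l = loss_fn(w, x, y)  # pre-update loss
        g = grad_fn(w, x, y)
        updates, opt_state = optimizer.update(g, opt_state, w)
        new_w = optax.apply_updates(w, updates)
        return (new_w, opt_state), l  # new carry, stacked output

    (_, _), losses = lax.scan(step, (w0, opt_state), (x_batches, y_batches))
    return losses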
Common pitfalls
- opt_state must be in the carry; closing over it means it never updates across steps, a common source of incorrect loss curves.
- Per-step loss is captured before the parameter update.
- lax.scan expects the xs to be stacked along axis 0; here that is the first dimension of x_batches and y_batches.
Inputs
- w0: scalar, the initial weight.
- x_batches: 2-D JAX array of shape (4, B), the per-step inputs (row i holds the B samples used at step i).
- y_batches: 2-D JAX array of shape (4, B), the per-step targets.
- lr: scalar, the SGD learning rate.
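The (4, B) layout is what makes lax.scan run exactly four steps, because scan slices its xs along axis 0. The sketch below checks that with a toy stream of samples. B = 3, flat_x and peek are hypothetical names chosen for illustration.

```python
import jax.numpy as jnp
from jax import lax

B = 3  # hypothetical batch size, for illustration only
flat_x = jnp.arange(4 * B, dtype=jnp.float32)  # 12 samples as one stream
x_batches = flat_x.reshape(4, B)               # row i is the batch seen at step i

def peek(count, x):
    # x is one row of x_batches, shape (B,); count tracks how many steps ran
    return count + 1, x.sum()

num_steps, row_sums = lax.scan(peek, jnp.int32(0), x_batches)
# num_steps → 4, row_sums → [3., 12., 21., 30.]
```

If data arrives as (B, 4) instead, transpose it first (x.T). Otherwise scan would run B steps over length-4 columns.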
Output
1-D array of shape (4,): the pre-update loss at each of the 4 steps.
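For a toy input the expected output can be worked out by hand, which is handy as a test oracle. Assume, hypothetically, that every batch has x = 1 in all B entries and y = a. The loss is then L(w) = B(w - a)^2 with gradient 2B(w - a). Plain SGD with learning rate η shrinks the error by a constant factor per step, and because the loss is recorded before each update, entry k uses the weight after k updates:

```latex
\begin{aligned}
w_{k+1} - a &= (w_k - a) - \eta \cdot 2B\,(w_k - a) = (1 - 2B\eta)\,(w_k - a),\\
\ell_k &= L(w_k) = B\,(w_0 - a)^2\,(1 - 2B\eta)^{2k}, \qquad k = 0, 1, 2, 3.
\end{aligned}
```

With B = 2, a = 2, w0 = 0 and η = 0.1, each step multiplies the loss by (1 - 0.4)^2 = 0.36, giving [8, 2.88, 1.0368, 0.373248]. Note that ℓ_0 = L(w0): the first entry reflects no update at all.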
Hints
optax
training
scan
loss-curve