Scan with Per-Step Outputs
Why this matters
lax.scan returns two things: (final_carry, stacked_outputs). The
previous problem discarded the outputs with _. Collecting per-step
quantities (losses, activations, attention weights) is the natural next
step, and it removes the need to append to a Python list inside a Python
loop, which JIT can only handle by unrolling every step.
The rule is simple: whatever you emit as the second element of the body's
return tuple, JAX stacks along a new leading dimension. If the body emits
a scalar, the output is shape (N,). If it emits a vector of length d,
the output is shape (N, d). The shape must be identical across all
steps: scan compiles the body once and assumes static shapes.
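The stacking rule can be checked with a small sketch. The names body and xs and the running-sum carry are illustrative assumptions, not part of the problem; the body emits one scalar and one length-3 vector per step.

```python
import jax.numpy as jnp
from jax import lax

def body(carry, x):
    # Running sum as the carry (illustrative choice).
    new_carry = carry + x
    scalar_out = new_carry                  # shape () per step
    vector_out = new_carry * jnp.ones(3)    # shape (3,) per step
    # The second element may be any pytree; each leaf is stacked separately.
    return new_carry, (scalar_out, vector_out)

xs = jnp.arange(5.0)                        # N = 5 steps
final, (scalars, vectors) = lax.scan(body, jnp.zeros(()), xs)

print(scalars.shape)   # (5,)   scalar per step -> (N,)
print(vectors.shape)   # (5, 3) vector of length 3 per step -> (N, 3)
```

Each leaf of the emitted output gains a leading dimension of size N, while the carry keeps its original shape.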
Worked mini-example
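import jax
import jax.numpy as jnp
from jax import lax

def loss_fn(w, x, y):
    # Sum of squared errors for a scalar weight w
    return jnp.sum((w * x - y) ** 2)

def step(w, batch):
    x, y = batch
    l = loss_fn(w, x, y)                # loss BEFORE update
    g = jax.grad(loss_fn)(w, x, y)
    return w - 0.1 * g, l               # (new_carry, per-step output)

# w0 is the initial weight; xs and ys have leading dimension N
final_w, losses = lax.scan(step, w0, (xs, ys))
# losses.shape == (N,): one loss per microbatch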
Note: the loss is computed with the current w (before the update),
so losses[i] is the pre-update loss at step i.
Common pitfalls
- Loss after update. Computing the loss with new_w instead of w gives the loss one step ahead: wrong semantics and harder to interpret.
- Inconsistent output shape. If your body emits different shapes on different iterations, scan will error at trace time. Keep it static.
- Forgetting jnp.atleast_1d on the scalar carry. final_w is a 0-D array; jnp.concatenate requires at least 1-D inputs, so wrap it with jnp.atleast_1d before concatenating.
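The last pitfall can be reproduced in isolation. This is a minimal sketch with placeholder values: final_w and losses here are made-up numbers standing in for real scan results.

```python
import jax.numpy as jnp

final_w = jnp.asarray(0.5)              # 0-D array, like a scalar carry
losses = jnp.array([4.0, 1.0, 0.25])    # placeholder per-step losses, shape (3,)

# Concatenating the 0-D carry directly is expected to be rejected.
try:
    jnp.concatenate([final_w, losses])
    raised = False
except (ValueError, TypeError):
    raised = True

# Promoting the carry to shape (1,) first makes the pack work.
packed = jnp.concatenate([jnp.atleast_1d(final_w), losses])
print(packed.shape)   # (4,)
```

jnp.atleast_1d leaves arrays that are already 1-D or higher unchanged, so it is safe to apply unconditionally.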
Problem
Implement scan_collect_losses(w, lr, x_batches, y_batches) using
lax.scan. Like scan_training_step, run N gradient-descent steps on
loss = sum((w*x - y)^2). Also collect the pre-update loss at each
step.
- Returns: 1-D array of shape (N+1,): [final_w, loss_step_0, ..., loss_step_{N-1}].
Pack the output as jnp.concatenate([jnp.atleast_1d(final_w), losses]).
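For checking a scan-based answer, it may help to have a slow reference that uses a plain Python loop. The sketch below is an assumption-laden aid, not the page's solution: reference_collect_losses and loss_fn are hypothetical names, and it assumes a scalar w with batches stacked along the leading dimension.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Same loss as in the problem statement: sum((w*x - y)^2)
    return jnp.sum((w * x - y) ** 2)

def reference_collect_losses(w, lr, x_batches, y_batches):
    # Hypothetical loop-based reference; deliberately does NOT use lax.scan.
    losses = []
    for x, y in zip(x_batches, y_batches):
        losses.append(loss_fn(w, x, y))             # pre-update loss
        w = w - lr * jax.grad(loss_fn)(w, x, y)     # gradient-descent step
    return jnp.concatenate([jnp.atleast_1d(w), jnp.stack(losses)])

# Tiny made-up data: N = 2 batches of 2 points, true relation y = 2x
x_batches = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y_batches = jnp.array([[2.0, 4.0], [6.0, 8.0]])
out = reference_collect_losses(jnp.asarray(0.0), 0.01, x_batches, y_batches)
print(out.shape)   # (3,): [final_w, loss_step_0, loss_step_1]
```

A scan implementation should agree with this reference up to floating-point tolerance, for example via jnp.allclose.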