medium primitives

Scan with Per-Step Outputs

Why this matters

lax.scan returns two things: (final_carry, stacked_outputs). The previous problem discarded the outputs with _. Collecting per-step quantities (losses, activations, attention weights) is the natural next step and removes the need for a Python list that defeats JIT.

The rule is simple: whatever you emit as the second element of the body's return tuple, JAX stacks along a new leading dimension. If the body emits a scalar, the output is shape (N,). If it emits a vector of length d, the output is shape (N, d). The shape must be identical across all steps, because scan compiles the body once and assumes static shapes.
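The stacking rule can be checked directly. The following is a minimal sketch; the body functions emit_scalar and emit_vector and the sizes N = 5, d = 3 are illustrative assumptions, not part of the problem.

```python
import jax.numpy as jnp
from jax import lax

N, d = 5, 3  # illustrative sizes (assumed)

def emit_scalar(carry, x):
    # carry is a running sum; the per-step output is a scalar
    new_carry = carry + x
    return new_carry, new_carry

def emit_vector(carry, x):
    # same carry, but the per-step output is a length-d vector
    new_carry = carry + x
    return new_carry, new_carry * jnp.ones(d)

xs = jnp.arange(N, dtype=jnp.float32)
init = jnp.zeros((), dtype=jnp.float32)

_, scalars = lax.scan(emit_scalar, init, xs)
_, vectors = lax.scan(emit_vector, init, xs)

print(scalars.shape)  # (5,)
print(vectors.shape)  # (5, 3)
```

In both cases the new leading axis has length N, the number of steps, regardless of what the body emits.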

Worked mini-example

import jax
import jax.numpy as jnp
from jax import lax

def loss_fn(w, x, y):
    return jnp.sum((w * x - y) ** 2)

def step(w, batch):
    x, y = batch
    l = loss_fn(w, x, y)             # loss BEFORE update
    g = jax.grad(loss_fn)(w, x, y)
    return w - 0.1 * g, l            # (new_carry, per-step output)

# w0: scalar initial weight; xs, ys: N stacked microbatches (leading dim N)
final_w, losses = lax.scan(step, w0, (xs, ys))
# losses.shape == (N,): one loss per microbatch

Note: the loss is computed with the current w (before the update), so losses[i] is the pre-update loss at step i.
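To make the pre-update convention concrete, here is a sketch that compares the scan against an ordinary Python loop. The toy data (xs, ys, w0) and the helper loss_fn are assumptions made for illustration; the snippet above does not define them.

```python
import jax
import jax.numpy as jnp
from jax import lax

def loss_fn(w, x, y):
    return jnp.sum((w * x - y) ** 2)

def step(w, batch):
    x, y = batch
    l = loss_fn(w, x, y)               # pre-update loss
    g = jax.grad(loss_fn)(w, x, y)
    return w - 0.1 * g, l

# Assumed toy data: N = 4 microbatches of 2 points each, with y = 2x.
xs = jnp.array([[0.1, 0.2], [0.3, 0.1], [0.2, 0.2], [0.1, 0.3]])
ys = 2.0 * xs
w0 = jnp.array(0.0)

final_w, losses = lax.scan(step, w0, (xs, ys))

# Reference: the same computation written as a plain Python loop.
w, ref_losses = w0, []
for x, y in zip(xs, ys):
    ref_losses.append(loss_fn(w, x, y))       # recorded before w changes
    w = w - 0.1 * jax.grad(loss_fn)(w, x, y)

# Agreement check between the scan and the loop.
print(jnp.allclose(losses, jnp.stack(ref_losses), rtol=1e-4, atol=1e-6))
```

The loop appends to a Python list, which is exactly what the scan version avoids; it is included only as a readable reference for the semantics.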

Common pitfalls

  • Loss after update. Computing the loss with new_w instead of w records the post-update loss on the same batch, so losses[0] no longer reflects the initial weights. That is the wrong semantics for this problem and harder to interpret.
  • Inconsistent output shape. The body is traced once, so an output whose shape would change from one iteration to the next makes scan fail at trace time. Keep it static.
  • Forgetting jnp.atleast_1d on the scalar carry. final_w is a 0-D array; jnp.concatenate requires at least 1-D inputs, so wrap it with jnp.atleast_1d before concatenating.
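The last pitfall is easy to reproduce in isolation. This is a small sketch with made-up values standing in for final_w and losses; the broad except clause is a hedge, since the exact exception type is an implementation detail of JAX.

```python
import jax.numpy as jnp

final_w = jnp.array(1.5)              # 0-D, like a scalar carry from scan
losses = jnp.array([0.9, 0.4, 0.1])   # 1-D, like stacked per-step outputs

try:
    jnp.concatenate([final_w, losses])   # 0-D input is rejected
    raised = False
except (ValueError, TypeError):
    raised = True

# Promoting the scalar to shape (1,) makes the concatenation valid.
packed = jnp.concatenate([jnp.atleast_1d(final_w), losses])
print(raised, packed.shape)
```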

Problem

Implement scan_collect_losses(w, lr, x_batches, y_batches) using lax.scan. Like scan_training_step, run N gradient-descent steps on loss = sum((w*x - y)^2). Also collect the pre-update loss at each step.

  • Returns: 1-D array of shape (N+1,): [final_w, loss_step_0, ..., loss_step_{N-1}].

Pack the output as jnp.concatenate([jnp.atleast_1d(final_w), losses]).
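One possible shape for the function, offered as a sketch rather than the reference solution. It assumes w is a scalar, that x_batches and y_batches share a leading dimension N, and it uses jax.value_and_grad to get the pre-update loss and the gradient in one call; the toy inputs at the bottom are invented for the usage example.

```python
import jax
import jax.numpy as jnp
from jax import lax

def scan_collect_losses(w, lr, x_batches, y_batches):
    def loss_fn(w, x, y):
        return jnp.sum((w * x - y) ** 2)

    def step(w, batch):
        x, y = batch
        # Loss and gradient are both evaluated at the pre-update w.
        loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
        return w - lr * grad, loss

    w = jnp.asarray(w, dtype=jnp.float32)   # fixed dtype for the carry
    final_w, losses = lax.scan(step, w, (x_batches, y_batches))
    return jnp.concatenate([jnp.atleast_1d(final_w), losses])

# Assumed toy inputs: N = 3 microbatches of 2 points each.
x_batches = jnp.ones((3, 2))
y_batches = 2.0 * jnp.ones((3, 2))
out = scan_collect_losses(0.0, 0.1, x_batches, y_batches)
print(out.shape)  # (4,)
```

Here out[0] is the final weight and out[1:] holds the N pre-update losses in step order.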

Hints

jax scan outputs
