medium
primitives
Checkpointed Deep Stack via scan
Why this matters
Deep stacks of identical layers (transformer blocks, ResNet stages, SSM recurrences) are the canonical use case for gradient checkpointing. For an N-layer network:
- Without checkpointing: O(N) activation memory (one stored tensor per layer).
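- With scan + checkpoint: the activations computed inside each step are dropped after the forward pass and recomputed during the backward pass. Only the carry entering each step is kept, so the stored state is N carries rather than every intermediate tensor of every layer.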
lax.scan provides the loop primitive. Wrapping its body function with
jax.checkpoint gives you the recompute-on-demand behaviour for every
scan iteration. The combination is the standard pattern for training very
deep or very long models.
Worked mini-example
import jax
import jax.numpy as jnp
from jax import lax

weights = jnp.stack([jnp.eye(2)] * 3)  # 3 identity layers, shape (3, 2, 2)
x = jnp.ones((1, 2))

# Non-checkpointed body: activations retained across all 3 iterations.
def layer_plain(y, w):
    return jax.nn.relu(y @ w), None

# Checkpointed body: activations of each step are dropped and recomputed.
@jax.checkpoint
def layer_ckpt(y, w):
    return jax.nn.relu(y @ w), None

def loss(weights, x, body):
    y_final, _ = lax.scan(body, x, weights)
    return jnp.sum(y_final ** 2)

g_plain = jax.grad(loss, argnums=0)(weights, x, layer_plain)
g_ckpt = jax.grad(loss, argnums=0)(weights, x, layer_ckpt)
assert jnp.allclose(g_plain, g_ckpt)  # passes: gradients are identical
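Identity weights and an all-ones input are a fairly forgiving test case. As a rough additional check, the sketch below compares the checkpointed scan against a plain Python loop on random weights. The names loss_scan and loss_loop, the shapes, the PRNG seed, and the tolerances are illustrative assumptions, not part of the exercise.

```python
import jax
import jax.numpy as jnp
from jax import lax

def layer(y, w):
    # One layer step: new carry, no per-step output.
    return jax.nn.relu(y @ w), None

def loss_scan(weights, x):
    # Checkpointed body, applied once per slice of `weights`.
    y_final, _ = lax.scan(jax.checkpoint(layer), x, weights)
    return jnp.sum(y_final ** 2)

def loss_loop(weights, x):
    # Reference: the same stack written as an explicit Python loop.
    y = x
    for i in range(weights.shape[0]):
        y = jax.nn.relu(y @ weights[i])
    return jnp.sum(y ** 2)

# Hypothetical test data: 4 layers of width 3, batch of 2.
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weights = jax.random.normal(key_w, (4, 3, 3)) * 0.5
x = jax.random.normal(key_x, (2, 3))

g_scan = jax.grad(loss_scan)(weights, x)
g_loop = jax.grad(loss_loop)(weights, x)

# Checkpointing changes what is stored, not what is computed.
assert g_scan.shape == weights.shape
assert jnp.allclose(g_scan, g_loop, rtol=1e-4, atol=1e-5)
```

The loop version is only a reference for correctness; for large N it traces every layer separately, which is what lax.scan avoids.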
Common pitfalls
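- lax.scan body signature: body(carry, x_slice) -> (new_carry, output). Here: layer(y, w) -> (new_y, None). The None is the per-step output (scan stacks these across steps); we discard it.
- Wrap the body, not the whole scan: jax.checkpoint expects a function, so jax.checkpoint(lax.scan(...)) is wrong. Wrapping a function that runs the entire scan treats the whole loop as a single checkpointed block; its backward pass reruns the full forward pass and stores every step's activations again, so peak memory is not reduced. Wrap the per-step function instead.
- @jax.checkpoint as a decorator is equivalent to jax.checkpoint(fn); both are valid.
- Very deep stacks (>100 layers): this combination is essential; without it, every layer's intermediate activations stay in memory until the backward pass, so memory grows linearly with depth.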
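To make the body signature concrete, here is a small sketch in which the second return value is an actual per-step output instead of None. The function name layer_with_output and the choice of output (the sum of each layer's activations) are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp
from jax import lax

def layer_with_output(y, w):
    new_y = jax.nn.relu(y @ w)
    # Second element is the per-step output; scan stacks these
    # along a new leading axis of length N.
    return new_y, jnp.sum(new_y)

# Function-call form, equivalent to decorating with @jax.checkpoint.
layer_ckpt = jax.checkpoint(layer_with_output)

weights = jnp.stack([jnp.eye(2)] * 3)  # 3 identity layers, shape (3, 2, 2)
x = jnp.ones((1, 2))

y_final, per_step = lax.scan(layer_ckpt, x, weights)

print(y_final.shape)   # final carry has the same shape as x: (1, 2)
print(per_step.shape)  # one scalar per layer, stacked: (3,)
```

Returning None for the output, as in the worked example, simply tells scan that there is nothing to stack.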
Problem
Implement grad_checkpointed_stack(weights, x) that:
- Defines a per-layer body layer(y, w) = (relu(y @ w), None) wrapped with jax.checkpoint.
- Applies it across N=8 layers via lax.scan.
- Computes loss = sum(y_final ** 2).
- Returns jax.grad(loss, argnums=0)(weights, x), the gradient w.r.t. weights.
Inputs:
- weights: 3-D JAX array of shape (8, d, d).
- x: 2-D JAX array of shape (N_batch, d).
Returns: array with the same shape as weights, the gradient of the loss w.r.t. each
layer weight matrix.
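One way the specification above could be satisfied is sketched below. It is assembled from the worked mini-example and is not the site's reference solution; the test inputs at the end (d = 4, batch of 2, identity weights) are assumptions chosen for a quick shape check.

```python
import jax
import jax.numpy as jnp
from jax import lax

def grad_checkpointed_stack(weights, x):
    # Per-layer body; jax.checkpoint makes each step's activations
    # recomputed during the backward pass instead of stored.
    @jax.checkpoint
    def layer(y, w):
        return jax.nn.relu(y @ w), None

    def loss(weights, x):
        # scan iterates over the leading axis of `weights` (one slice per layer).
        y_final, _ = lax.scan(layer, x, weights)
        return jnp.sum(y_final ** 2)

    # Gradient w.r.t. the stacked weights (argument 0).
    return jax.grad(loss, argnums=0)(weights, x)

# Hypothetical inputs: 8 identity layers of width 4, batch of 2.
weights = jnp.stack([jnp.eye(4)] * 8)
x = jnp.ones((2, 4))
grads = grad_checkpointed_stack(weights, x)
assert grads.shape == weights.shape
```

Nothing in the function hard-codes N=8; the number of layers comes from the leading dimension of weights.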
Hints
jax
checkpoint
scan