medium primitives

Checkpointed Deep Stack via scan

Why this matters

Deep stacks of identical layers (transformer blocks, ResNet stages, SSM recurrences) are the canonical use case for gradient checkpointing. For an N-layer network:

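  • Without checkpointing: every intermediate activation inside every layer is kept for the backward pass, so activation memory grows as O(N × intermediates per layer).
  • With scan + checkpoint: only the input carry of each step is saved (still one tensor per layer, so O(N) carries), while the activations computed inside each step are dropped and recomputed during the backward pass. The saving is the per-layer factor, which is large when each layer has many intermediates.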

lax.scan provides the loop primitive. Wrapping its body function with jax.checkpoint gives you the recompute-on-demand behaviour for every scan iteration. The combination is the standard pattern for training very deep or very long models.
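
The pattern can be sketched on a layer with more than one intermediate, which is where recomputation starts to pay off. This is a minimal sketch, not a reference implementation: mlp_block, run_stack, the parameter names w_in / w_out, and the sizes are all hypothetical choices made for illustration. It relies on lax.scan slicing every leaf of a parameter pytree along its leading axis.

```python
import jax
import jax.numpy as jnp
from jax import lax

def mlp_block(y, params):
    # Hypothetical residual block; `params` holds ONE layer's slice.
    h = jax.nn.gelu(y @ params["w_in"])    # intermediate inside the step
    return y + h @ params["w_out"], None   # (new_carry, per-step output)

def run_stack(params, x, use_checkpoint):
    # Wrap the per-step body, then let scan loop over the stacked layers.
    body = jax.checkpoint(mlp_block) if use_checkpoint else mlp_block
    y_final, _ = lax.scan(body, x, params)
    return jnp.sum(y_final ** 2)

n_layers, d, hidden = 4, 3, 5
k_in, k_out = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w_in": 0.1 * jax.random.normal(k_in, (n_layers, d, hidden)),
    "w_out": 0.1 * jax.random.normal(k_out, (n_layers, hidden, d)),
}
x = jnp.ones((2, d))

# Gradients w.r.t. the parameter pytree, with and without checkpointing.
g_plain = jax.grad(run_stack)(params, x, False)
g_ckpt = jax.grad(run_stack)(params, x, True)
```

Both calls should produce matching gradients; checkpointing changes what is stored between the forward and backward passes, not the result.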

Worked mini-example

import jax
import jax.numpy as jnp
from jax import lax

weights = jnp.stack([jnp.eye(2)] * 3)   # 3 identity layers, shape (3,2,2)
x = jnp.ones((1, 2))

# Non-checkpointed body: activations retained across all 3 iterations.
def layer_plain(y, w):
    return jax.nn.relu(y @ w), None

# Checkpointed body: activations of each step are dropped and recomputed.
@jax.checkpoint
def layer_ckpt(y, w):
    return jax.nn.relu(y @ w), None

def loss(weights, x, body):
    y_final, _ = lax.scan(body, x, weights)
    return jnp.sum(y_final ** 2)

g_plain = jax.grad(loss, argnums=0)(weights, x, layer_plain)
g_ckpt  = jax.grad(loss, argnums=0)(weights, x, layer_ckpt)

assert jnp.allclose(g_plain, g_ckpt)   # True: gradients are identical

Common pitfalls

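  • lax.scan body signature: body(carry, x_slice) -> (new_carry, output). Here: layer(y, w) -> (new_y, None). The second element is the per-step output, which scan stacks across steps; returning None means nothing is collected.
  • Wrap the BODY, not the whole scan: jax.checkpoint expects a function, so jax.checkpoint(lax.scan(...)) hands it the scan's result instead. Wrapping a function that contains the whole scan does not help either: the backward pass reruns the full forward scan, which stores every step's activations as it goes, so peak memory is not reduced. Wrap the per-step function instead.
  • @jax.checkpoint as a decorator is equivalent to jax.checkpoint(fn); both are valid.
  • Very deep stacks (>100 layers): this combination matters most here; without it, every layer's internal activations are held for the backward pass, and that cost scales with depth.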
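
The body signature and the function-call form of jax.checkpoint can be sketched as follows. This is an illustrative variant of the layer above, not part of the exercise: here the body returns a real per-step output (the sum of each layer's activations) instead of None, so the stacking behaviour of scan becomes visible.

```python
import jax
import jax.numpy as jnp
from jax import lax

def layer(y, w):
    new_y = jax.nn.relu(y @ w)
    # Second return value is the per-step output; scan stacks it over steps.
    return new_y, jnp.sum(new_y)

# Function-call form; same effect as decorating `layer` with @jax.checkpoint.
layer_ckpt = jax.checkpoint(layer)

weights = jnp.stack([jnp.eye(2)] * 3)   # 3 identity layers, shape (3, 2, 2)
x = jnp.ones((1, 2))

# The checkpointed per-step function is what gets passed to scan.
y_final, per_step = lax.scan(layer_ckpt, x, weights)

def loss(weights, x):
    y, _ = lax.scan(layer_ckpt, x, weights)
    return jnp.sum(y ** 2)

grads = jax.grad(loss)(weights, x)      # same shape as weights
```

With identity weights and an all-ones input, every layer leaves the carry unchanged, so per_step holds three equal entries, one per scan step.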

Problem

Implement grad_checkpointed_stack(weights, x) that:

  1. Defines a per-layer body layer(y, w) = (relu(y @ w), None) wrapped with jax.checkpoint.
  2. Applies it across N=8 layers via lax.scan.
  3. Computes loss = sum(y_final ** 2).
  4. Returns jax.grad(loss, argnums=0)(weights, x), the gradient w.r.t. weights.

Arguments:

  • weights: 3-D jax array (8, d, d).
  • x: 2-D jax array (N_batch, d).

Returns: an array with the same shape as weights, holding the gradient of the loss w.r.t. each layer's weight matrix.
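
One possible shape for a solution, together with a sanity check, is sketched below. It is an assumption-laden sketch rather than the official answer: reference_grad is a hypothetical helper that unrolls the layers in a plain Python loop without checkpointing, and the random test inputs are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax

def grad_checkpointed_stack(weights, x):
    @jax.checkpoint
    def layer(y, w):
        # Per-layer body: (new_carry, per-step output).
        return jax.nn.relu(y @ w), None

    def loss(weights, x):
        y_final, _ = lax.scan(layer, x, weights)
        return jnp.sum(y_final ** 2)

    return jax.grad(loss, argnums=0)(weights, x)

def reference_grad(weights, x):
    # Hypothetical check: unrolled loop, no scan, no checkpointing.
    def loss(weights, x):
        y = x
        for i in range(weights.shape[0]):
            y = jax.nn.relu(y @ weights[i])
        return jnp.sum(y ** 2)

    return jax.grad(loss)(weights, x)

d = 4
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weights = jnp.eye(d) + 0.1 * jax.random.normal(key_w, (8, d, d))
x = jax.random.normal(key_x, (3, d))

g = grad_checkpointed_stack(weights, x)
g_ref = reference_grad(weights, x)
```

Comparing g against g_ref with jnp.allclose is a quick way to confirm that the scan and the checkpoint wrapper did not change the computed gradient.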

Hints

jax checkpoint scan
