nn.checkpoint (gradient checkpointing)

Why this matters

Backpropagation needs the intermediate values (activations) from the forward pass to compute gradients. For a deep transformer, that’s a LOT: every attention matrix, every MLP intermediate, every layer norm output gets saved during the forward pass and freed only after the corresponding backward step has used it.

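On GPUs/TPUs with finite memory, this caps how big a model you can fit. Activation memory grows with batch size × sequence length × number of layers, while parameter memory doesn’t depend on batch size or sequence length at all. The classic example: a 70B-parameter model trained with batch size 8 and sequence length 4096 can run out of memory not because the parameters are too large but because the activations are.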

Gradient checkpointing (a.k.a. rematerialization, or remat) trades compute for memory: instead of saving every intermediate, it saves only a few “checkpoints” along the forward pass (typically the inputs to each wrapped segment) and recomputes everything in between on demand during backward. Memory drops sharply; in exchange, the forward computation of each recomputed segment runs twice.
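A minimal pure-JAX sketch of the trade-off, using jax.checkpoint on a small function. The names block, loss_plain and loss_remat are illustrative, not from any library. The gradients are the same either way; what changes is what is kept for the backward pass, which jax.ad_checkpoint.print_saved_residuals lists (its exact printed format depends on the JAX version).

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals


def block(w, x):
    # Two matmuls with a tanh after each; the intermediate h is an
    # activation that the backward pass needs.
    h = jnp.tanh(x @ w)
    return jnp.tanh(h @ w)


def loss_plain(w, x):
    return jnp.sum(block(w, x) ** 2)


def loss_remat(w, x):
    # jax.checkpoint: keep only block's inputs, recompute h (and the rest
    # of block's internals) during the backward pass.
    return jnp.sum(jax.checkpoint(block)(w, x) ** 2)


w = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
x = jnp.ones((2, 4))

g_plain = jax.grad(loss_plain)(w, x)
g_remat = jax.grad(loss_remat)(w, x)

# Compare what each version saves for the backward pass.
print_saved_residuals(loss_plain, w, x)
print_saved_residuals(loss_remat, w, x)
```

The remat version should list fewer saved intermediates from inside block, at the cost of running block’s forward computation again during backward.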

nn.checkpoint (alias nn.remat) is the Flax-aware lift of jax.checkpoint. Wrap a Module class and gradients computed through it will rematerialize the activations.

Forward-pass equivalence

Critical: nn.checkpoint(MLP)(...) produces the identical forward output as MLP(...) given the same parameters. The only difference is which intermediates JAX saves for the backward pass, and that only matters when you differentiate (jax.grad, jax.value_and_grad, jax.vjp).

So this problem just verifies you wrap correctly and get the same output a non-checkpointed MLP would have produced. The memory-saving behavior happens automatically the moment you backprop through it.

The canonical incantation

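import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.features)(x)
        x = nn.relu(x)
        x = nn.Dense(self.features)(x)
        return x

F = 16                               # output width (example value)
rng = jax.random.PRNGKey(0)
x = jnp.ones((8,))                   # example input of shape (D,)

CkptMLP = nn.checkpoint(MLP)         # lift the class, not an instance
model = CkptMLP(features=F)
params = model.init(rng, x)
out = model.apply(params, x)         # shape (F,)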

nn.checkpoint is just nn.remat under a friendlier name. Both work identically for this purpose.

When NOT to checkpoint

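Checkpointing isn’t free: every wrapped segment’s forward pass runs again during backward, so wrapping everything at fine granularity slows training without a proportional memory win. Common practice:

  • Per-block (e.g., one transformer block at a time): a good balance. Big memory win for roughly 33% extra compute, since recomputation adds about one extra forward pass and the backward pass already costs about two.
  • Per group of K layers: a tunable knob. Larger K saves fewer boundary activations but holds more recomputed activations at once during backward.
  • Selective via policy=... (see pos 84): save the results that are expensive to recompute (e.g., matmul outputs) and recompute the cheap elementwise ops. Often the best of both.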

Don’t checkpoint a nn.Dense by itself: its backward pass only needs its input (plus its kernel), which the checkpoint has to save anyway, so you’d pay for recomputation with no memory gain.

Common pitfalls

  • Wrapping the instance instead of the class: nn.checkpoint(MLP()) doesn’t work; pass the CLASS, then instantiate.
  • Expecting a different forward output: it’s identical. The “checkpointing” only affects what’s saved across the forward/backward boundary.
  • Double-wrapping (nn.checkpoint(nn.checkpoint(MLP))): legal but pointless. The outer checkpoint already keeps only the segment’s inputs, so the inner one frees nothing more and can only add recomputation.

Problem

Implement nn_checkpoint_forward(seed, x, features):

  1. Define MLP(nn.Module) with features field. In @nn.compact: Dense(features) → relu → Dense(features).
  2. Wrap with nn.checkpoint: CkptMLP = nn.checkpoint(MLP).
  3. Instantiate, init with PRNGKey(seed) and x.
  4. Apply, return flattened output.

The forward result is the same as a non-wrapped MLP — this is purely a structural / API correctness check.

Inputs:

  • seed: int.
  • x: 1-D (D,).
  • features: int F.

Output: 1-D (F,).

Hints

flax nn-checkpoint remat lifted-transforms
