nn.checkpoint (gradient checkpointing)
Why this matters
Backpropagation needs every forward-pass activation to compute gradients. For a deep transformer, that’s a LOT: every attention matrix, every MLP intermediate, every layer norm output gets saved during the forward pass and freed only after the corresponding backward pass uses it.
On GPUs/TPUs with finite memory, this caps how big a model you can fit. The classic example: a 70B-parameter model with batch size 8 and sequence length 4096 might run out of memory not because the params are too large but because the activations are.
Gradient checkpointing (a.k.a. rematerialization, remat) trades compute for memory: instead of saving every intermediate, it saves only a few “checkpoints” along the forward pass and recomputes the rest on demand during backward. Memory drops sharply, but the forward computation of each recomputed segment runs twice: once in the forward pass and again during backward.
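A minimal sketch of the trade-off in plain JAX, before any Flax lifting. The two-layer block function, the shapes, and all names here are made up for illustration. Wrapping the function in jax.checkpoint should leave the gradient unchanged, while print_saved_residuals (from jax.ad_checkpoint) lists what each version keeps alive for the backward pass. The exact printout differs between JAX versions, so no output is shown.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals


def block(w, x):
    # Two tanh layers; their intermediates are what autodiff normally saves.
    h = jnp.tanh(w @ x)
    return jnp.tanh(w @ h)


def loss(w, x):
    return jnp.sum(block(w, x))


def loss_remat(w, x):
    # Same computation, but the intermediates inside `block` are recomputed
    # during the backward pass instead of being stored.
    return jnp.sum(jax.checkpoint(block)(w, x))


w = jax.random.normal(jax.random.PRNGKey(0), (8, 8))
x = jnp.ones((8,))

g_plain = jax.grad(loss)(w, x)
g_remat = jax.grad(loss_remat)(w, x)

# Compare what each version saves as residuals for the backward pass.
print_saved_residuals(loss, w, x)
print_saved_residuals(loss_remat, w, x)
```

Only the set of saved residuals changes; the gradient itself is the same.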
nn.checkpoint (alias nn.remat) is the Flax-aware lift of
jax.checkpoint. Wrap a Module class and gradients computed
through it will rematerialize the activations.
Forward-pass equivalence
Critical: nn.checkpoint(MLP)(...) produces the identical
forward output as MLP(...). The only difference is what
JAX saves as residuals for the backward pass, which only
matters when you differentiate: jax.grad / jax.value_and_grad
(or another autodiff transform such as jax.vjp).
So this problem just verifies you wrap correctly and get the same output a non-checkpointed MLP would have produced. The memory-saving behavior happens automatically the moment you backprop through it.
The canonical incantation
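import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.features)(x)
        x = nn.relu(x)
        x = nn.Dense(self.features)(x)
        return x

rng = jax.random.PRNGKey(0)   # example values for the names used below
x = jnp.ones((3,))            # input of shape (D,)
F = 4                         # output features

CkptMLP = nn.checkpoint(MLP)  # lift the class
model = CkptMLP(features=F)
params = model.init(rng, x)
out = model.apply(params, x)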
nn.checkpoint is just nn.remat under a friendlier name. Both
work identically for this purpose.
When NOT to checkpoint
Wrapping every layer in nn.checkpoint slows training because
the whole forward gets re-run during backward. Common practice:
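- Per-block (e.g., one transformer block at a time): a good balance, with a big memory win for roughly 33% extra compute (one extra forward pass).
- Per-layer-stack (every K layers): finer control over the memory/compute trade-off.
- Selective, via policy=... (see pos 84): save the values that are expensive to recompute (e.g. matmul outputs) and recompute the cheap ones (elementwise ops). Best of both.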
Don’t checkpoint a lone nn.Dense: it’s already cheap, and its
backward only needs its input and kernel, which stay saved
anyway, so you’d just slow training for no memory gain.
Common pitfalls
- Wrapping the instance instead of the class: nn.checkpoint(MLP()) doesn’t work; pass the CLASS, then instantiate.
- Expecting a different forward output: it’s identical. The “checkpointing” only affects what’s saved across the forward/backward boundary.
- Double-wrapping (nn.checkpoint(nn.checkpoint(MLP))): legal but redundant. Checkpointing twice doesn’t save more memory; it just adds layers of recompute.
Problem
Implement nn_checkpoint_forward(seed, x, features):
- Define MLP(nn.Module) with a features field. In @nn.compact: Dense(features) → relu → Dense(features).
- Wrap with nn.checkpoint: CkptMLP = nn.checkpoint(MLP).
- Instantiate, then init with PRNGKey(seed) and x.
- Apply and return the flattened output.
The forward result is the same as a non-wrapped MLP — this is purely a structural / API correctness check.
Inputs:
- seed: int.
- x: 1-D, shape (D,).
- features: int F.
Output: 1-D (F,).