nn.remat with checkpoint policies
Why this matters
Plain nn.checkpoint(MLP) (pos 82) is all-or-nothing: it
rematerializes EVERY intermediate inside the wrapped Module
during backward. That saves a lot of memory, but it costs roughly one
extra forward pass of the wrapped Module: you recompute big matmuls
just to get their output activations back.
The next step up is selective rematerialization: tell JAX
to save SOME intermediates (typically those that are expensive to
recompute but cheap to store, such as matmul outputs) and recompute
the rest (cheap element-wise ops). The argument that controls this is
policy=....
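A policy is a function policy(prim, *avals, **params) -> bool
that JAX consults for every primitive op in the traced function.
prim is the primitive (e.g. dot_general), avals are the abstract
shapes and dtypes of its inputs, and params are the primitive's
parameters. Returning True means “save this output across
forward/backward”; False means “OK to recompute.”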
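Here is a minimal sketch of a hand-written policy. It assumes a recent JAX release in which jax.checkpoint calls the policy with the primitive, the input avals and the primitive's params, and expects a bool back. The names save_matmuls_only and layer are hypothetical, and the rule simply mirrors what the built-in dots_saveable does.

```python
import jax
import jax.numpy as jnp


def save_matmuls_only(prim, *avals, **params):
    """Hypothetical policy: keep matmul outputs, recompute everything else."""
    # prim is a JAX primitive; avals describe its inputs (shape/dtype only).
    return prim.name == "dot_general"


def layer(w1, w2, x):
    h = jnp.sin(x @ w1)  # matmul output may be saved; sin is recomputed
    return jnp.sum(jnp.tanh(h @ w2))


ckpt_layer = jax.checkpoint(layer, policy=save_matmuls_only)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (8, 8))
w2 = jax.random.normal(k2, (8, 8))
x = jax.random.normal(k3, (4, 8))

# Gradients match the unwrapped function; only the saved residuals differ.
g_plain = jax.grad(layer, argnums=(0, 1))(w1, w2, x)
g_ckpt = jax.grad(ckpt_layer, argnums=(0, 1))(w1, w2, x)
```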
The named policies
JAX ships several pre-baked policies in jax.checkpoint_policies:
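- nothing_saveable: recompute everything. Most aggressive.
- everything_saveable: save everything. Roughly equivalent to no checkpointing.
- dots_saveable (alias checkpoint_dots): save the output of every dot_general (matmul). Recompute element-wise ops (relu, layernorm, etc.), which XLA fuses and which are cheap to redo.
- dots_with_no_batch_dims_saveable: same idea, but save only dot_general outputs whose dimension numbers contain no batch dimensions. “Batch” here means a dot_general batch dimension, not the leading batch axis of the activations. A Dense projection (e.g. the QKV or output projection in attention) contracts only the feature axis, so its output is saved. The attention-score einsum batches over the batch and head axes, so it is recomputed, and its (batch, heads, seq, seq) output is typically the largest activation in the block.
- save_anything_except_these_names(*names): save everything except values tagged with jax.ad_checkpoint.checkpoint_name under one of these names.
- save_only_these_names(*names): save only values tagged with one of these names.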
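The named policies only trade memory for compute, and they should not change the gradients. The sketch below checks this. It tags one value with jax.ad_checkpoint.checkpoint_name, wraps the same function with each policy from the list, and confirms that jax.grad returns the same gradients every time. The function name block and the tag "proj" are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def block(w1, w2, x):
    # Tag the first projection so the name-based policies can refer to it.
    h = checkpoint_name(x @ w1, "proj")
    return jnp.sum(jnp.tanh(jnp.sin(h) @ w2))


cp = jax.checkpoint_policies
policies = {
    "nothing_saveable": cp.nothing_saveable,
    "everything_saveable": cp.everything_saveable,
    "dots_saveable": cp.dots_saveable,
    "dots_with_no_batch_dims_saveable": cp.dots_with_no_batch_dims_saveable,
    "save_only_these_names": cp.save_only_these_names("proj"),
    "save_anything_except_these_names": cp.save_anything_except_these_names("proj"),
}

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (8, 8))
w2 = jax.random.normal(k2, (8, 8))
x = jax.random.normal(k3, (4, 8))

reference = jax.grad(block, argnums=(0, 1))(w1, w2, x)
grads = {
    name: jax.grad(jax.checkpoint(block, policy=p), argnums=(0, 1))(w1, w2, x)
    for name, p in policies.items()
}
```

To see which values a given policy actually keeps, jax.ad_checkpoint.print_saved_residuals(f, *args) prints the residuals saved for the backward pass.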
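The dots_with_no_batch_dims_saveable policy is a common
“good default” for transformers. It keeps the projection outputs, which
are relatively small to store but cost a full matmul to recompute. It
recomputes the attention scores, whose storage grows quadratically
with sequence length. Activation memory typically drops substantially
for a modest amount of extra compute.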
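The “batch dims” in the policy name refer to dot_general dimension numbers. The sketch below uses a hypothetical attention_scores function with toy shapes. It builds the projections and the score matmul with lax.dot_general, reads the dimension numbers back from the jaxpr, and applies the same save/recompute rule the policy uses. The rule is re-implemented by hand here, so this is an illustration rather than a call into the policy itself.

```python
import jax
import jax.numpy as jnp
from jax import lax


def attention_scores(x, wq, wk):
    # Projections: contract x's feature axis with w's input axis; no batch dims.
    no_batch = (((2,), (0,)), ((), ()))
    q = lax.dot_general(x, wq, no_batch)  # (B, S, D)
    k = lax.dot_general(x, wk, no_batch)  # (B, S, D)
    # Scores: contract over D, with axis 0 (B) as a dot_general batch dim.
    return lax.dot_general(q, k, (((2,), (2,)), ((0,), (0,))))  # (B, S, S)


B, S, D = 2, 16, 4
x, wq, wk = jnp.ones((B, S, D)), jnp.ones((D, D)), jnp.ones((D, D))

decisions = []
for eqn in jax.make_jaxpr(attention_scores)(x, wq, wk).jaxpr.eqns:
    if eqn.primitive.name == "dot_general":
        (_, _), (lhs_batch, rhs_batch) = eqn.params["dimension_numbers"]
        # Same rule as the policy: save only if there are no batch dims.
        saved = not lhs_batch and not rhs_batch
        decisions.append((eqn.outvars[0].aval.shape, saved))

print(decisions)
# [((2, 16, 4), True), ((2, 16, 4), True), ((2, 16, 16), False)]
```

The saved projections hold B·S·D values each. The recomputed scores hold B·S·S values, and that share grows with sequence length (and with the number of heads in a real multi-head block).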
The canonical incantation
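import jax
import jax.numpy as jnp
import flax.linen as nn

# MLP is the Module class from the Problem below; F, rng and x are placeholders.
F = 8
rng = jax.random.PRNGKey(0)
x = jnp.ones((16,))

# Pass the CLASS; nn.remat returns a new, rematerialized Module class.
RematMLP = nn.remat(
    MLP,
    policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
)
model = RematMLP(features=F)
params = model.init(rng, x)
out = model.apply(params, x)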
Forward output is identical. Policy only affects what’s saved across the forward/backward boundary.
Forward equivalence again
Same as pos 82: changing the policy doesn’t change forward numerics. So this problem just verifies you wrap correctly with a policy specified.
Why this lands in production
Real projects (e.g. T5X and MaxText) use these named policies to dial the memory/compute trade-off without writing a custom policy function. A common recipe for transformers:
Block = nn.remat(
    TransformerBlock,
    policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
)
Stack = nn.scan(Block, ..., length=num_layers)
Per-block remat with a sensible policy + scan over layers = the JAX-native way to train very deep transformers without OOM.
Common pitfalls
- Wrong policy argument name: it’s policy=, not checkpoint_policy= or save_policy=.
- Wrapping the instance: pass the CLASS to nn.remat, then construct the class it returns.
- Custom policy that always returns False: equivalent to nothing_saveable, which is slow and pointless if you also have cheap-to-store values you’d happily save.
Problem
Implement nn_remat_policy(seed, x, features):
- Define MLP(nn.Module) with a features field. Layers: Dense → relu → Dense → relu → Dense.
- Wrap with nn.remat and policy jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
  RematMLP = nn.remat(MLP, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
- Init with PRNGKey(seed) and x.
- Apply.
- Return flattened output.
Forward result equals the un-wrapped MLP. The policy only
matters under jax.grad.
Inputs:
- seed: int.
- x: 1-D (D,).
- features: int F.
Output: 1-D (F,).