NNX Remat Policies
Why this matters
Plain nnx.remat(f) is “all or nothing” — every intermediate inside
f gets recomputed on the backward pass. That’s a sledgehammer.
Sometimes the cheap-to-store, expensive-to-recompute things (like
matmul outputs) are exactly what you’d PREFER to save, while the
cheap-to-recompute things (like elementwise activations) are what
you’d prefer to recompute.
jax.checkpoint_policies is a small library of named functions that
answer the question “should we save THIS intermediate?” for each
op encountered during tracing. Pass one to nnx.remat(policy=...)
and you get fine-grained control: save the matmuls, recompute the
activations.
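To make "save the matmuls, recompute the activations" concrete, here is a minimal pure-JAX sketch. It uses jax.checkpoint directly, which accepts the same jax.checkpoint_policies functions; the names layer, loss_plain, and loss_remat are hypothetical and exist only for this illustration. The policy changes what is stored for the backward pass, so the two gradients should agree.

```python
import jax
import jax.numpy as jnp


def layer(w, x):
    # One matmul (no batch dims in the dot) followed by an elementwise op.
    return jnp.tanh(x @ w)


# Same layer, but rematerialized: dot outputs are saved, the rest is recomputed.
policy = jax.checkpoint_policies.dots_with_no_batch_dims_saveable
layer_remat = jax.checkpoint(layer, policy=policy)


def loss_plain(w, x):
    return jnp.sum(layer(w, x) ** 2)


def loss_remat(w, x):
    return jnp.sum(layer_remat(w, x) ** 2)


w = jnp.full((4, 4), 0.1)
x = jnp.arange(8.0).reshape(2, 4)

g_plain = jax.grad(loss_plain)(w, x)
g_remat = jax.grad(loss_remat)(w, x)
print(jnp.allclose(g_plain, g_remat, atol=1e-5))
```

The memory/compute trade-off is not visible in the numbers themselves; it only shows up in what the backward pass has to hold on to.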
A few common policies
- dots_with_no_batch_dims_saveable: save the output of any matmul (dot_general) that has no batch dimensions, and recompute everything else. Useful when batched matmuls dominate memory but the non-batch dots (small ones) are cheap to keep around.
- dots_saveable: save ALL matmul outputs. Recompute everything else.
- nothing_saveable: recompute everything (equivalent to plain nnx.remat(f) with no policy).
- everything_saveable: save everything (equivalent to no remat). Useful as a sanity-check baseline.
- save_anything_except_these_names("attention_logits", ...): save everything except specifically-named intermediates, where the names are attached with jax.ad_checkpoint.checkpoint_name. Used in LM training to recompute attention logits (memory-hungry) while keeping LayerNorm outputs.
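The name-based policy only has something to act on once an intermediate has been tagged. A minimal sketch, assuming a toy single-head attention-weights function (attention_weights and make_loss are hypothetical names for this example, and it again uses plain jax.checkpoint rather than nnx.remat):

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def attention_weights(q, k):
    # Tag the logits so a name-based policy can refer to them.
    logits = checkpoint_name(q @ k.T, "attention_logits")
    return jax.nn.softmax(logits, axis=-1)


# Save every intermediate except the one tagged "attention_logits".
policy = jax.checkpoint_policies.save_anything_except_these_names("attention_logits")
attention_weights_remat = jax.checkpoint(attention_weights, policy=policy)


def make_loss(fn):
    def loss(q, k):
        return jnp.sum(fn(q, k) ** 2)
    return loss


q = jnp.linspace(-1.0, 1.0, 12).reshape(3, 4)
k = jnp.linspace(0.5, -0.5, 12).reshape(3, 4)

g_plain = jax.grad(make_loss(attention_weights))(q, k)
g_remat = jax.grad(make_loss(attention_weights_remat))(q, k)
print(jnp.allclose(g_plain, g_remat, atol=1e-5))
```

Outside a rematerialized function, checkpoint_name behaves like an identity, so tagging intermediates is harmless in code paths that never use remat.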
What changes when you set policy
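Forward result: identical to plain remat (and to no remat). The policy only affects which intermediates the backward pass finds saved and which it has to recompute. That is why this problem can verify correctness on the forward alone: the policy argument is passed through on the forward path, but the test doesn’t take a gradient, so the save/recompute decisions themselves are never exercised.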
The recipe
import jax
from flax import nnx

@nnx.remat(policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def fwd(model, x):
    return model(x)
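Notice we call nnx.remat with parens: nnx.remat(policy=...)
returns a decorator. That’s the usual pattern for configurable
NNX transforms. (Plain jax.checkpoint takes the function as its
first argument, so the pure-JAX equivalent is typically
functools.partial(jax.checkpoint, policy=...).)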
Worked example: a tiny 2-layer MLP
class MLP(nnx.Module):
    def __init__(self, features, *, rngs):
        self.l1 = nnx.Linear(features, features, rngs=rngs)
        self.l2 = nnx.Linear(features, features, rngs=rngs)

    def __call__(self, x):
        x = jax.nn.relu(self.l1(x))
        x = jax.nn.relu(self.l2(x))
        return x
With dots_with_no_batch_dims_saveable, the outputs of the two
(x @ W) matmuls are saved during the forward pass, and the backward
pass recomputes the bias adds and relu activations from them.
Common pitfalls
- Picking a policy that saves nothing AND has high recompute cost. You’ll get the worst of both worlds — slow AND not saving memory.
- Forgetting that policy doesn’t affect the forward. If you want to debug “what’s saved”, inspect HLO or use jax.ad_checkpoint.print_saved_residuals rather than expecting the forward output to differ.
- Mixing policy= with arbitrary lambdas. The policy must be a callable matching JAX’s expected signature. Stick to the named ones in jax.checkpoint_policies unless you really mean to write a custom one.
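For the “what’s saved” question, a small sketch of print_saved_residuals on a pure-JAX function may help. The function two_layers and its weights are hypothetical, and tanh stands in for the activation; the exact text printed depends on the JAX version, so no expected output is shown.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals


def two_layers(w1, w2, x):
    # Two matmuls, each followed by an elementwise activation.
    h = jnp.tanh(x @ w1)
    return jnp.tanh(h @ w2)


w1 = jnp.full((4, 4), 0.1)
w2 = jnp.full((4, 4), -0.2)
x = jnp.arange(8.0).reshape(2, 4)

# Compare what each named policy keeps around for the backward pass.
for name in ("nothing_saveable",
             "dots_with_no_batch_dims_saveable",
             "everything_saveable"):
    policy = getattr(jax.checkpoint_policies, name)
    f = jax.checkpoint(two_layers, policy=policy)
    print(f"--- {name} ---")
    print_saved_residuals(f, w1, w2, x)
```

Comparing the three listings side by side is usually quicker than reading HLO when the goal is just to confirm that a policy saves what you expect.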
Problem
Implement nnx_remat_policy(seed, x, features):
- Define an MLP Module with two nnx.Linear(F, F) layers, each followed by relu.
- Build it with int(features) and nnx.Rngs(int(seed)).
- Wrap a forward function with @nnx.remat(policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable).
- Call fwd(model, x) and return out.reshape(-1).
The output is the same as un-rematted; the test verifies the policy keyword doesn’t break the forward.
Inputs:
- seed: float (cast to int).
- x: 1-D (F,).
- features: float (cast to int), the feature count F.
Output: 1-D (F,) — final activation.