NNX Remat Policies

Why this matters

Plain nnx.remat(f) is “all or nothing” — every intermediate inside f gets recomputed on the backward pass. That’s a sledgehammer. Sometimes the cheap-to-store, expensive-to-recompute things (like matmul outputs) are exactly what you’d PREFER to save, while the cheap-to-recompute things (like elementwise activations) are what you’d prefer to recompute.

jax.checkpoint_policies is a small library of named functions that answer the question “should we save THIS intermediate?” for each op encountered during tracing. Pass one to nnx.remat(policy=...) and you get fine-grained control: save the matmuls, recompute the activations.

A few common policies

  • dots_with_no_batch_dims_saveable: save the output of any matmul (dot_general) that has no batch dimensions. Plain x @ W products, like the ones inside nnx.Linear, qualify; batched matmuls, such as the per-head attention products in a transformer, do not and are recomputed. Useful when the batched matmuls dominate memory.
  • dots_saveable: save ALL matmul outputs. Recompute everything else.
  • nothing_saveable: recompute everything (equivalent to plain nnx.remat(f) with no policy).
  • everything_saveable: save everything (equivalent to no remat). Useful as a sanity-check baseline.
  • save_anything_except_these_names("attention_logits", ...): save everything except intermediates tagged with the given names (names are attached with jax.ad_checkpoint.checkpoint_name). Used in LM training to recompute the memory-hungry attention logits while keeping LayerNorm outputs.
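To make the name-based policy concrete, here is a minimal sketch in plain JAX (jax.checkpoint rather than nnx.remat, to keep it short). The function block, the array shapes, and the use of tanh are illustrative assumptions; only the "attention_logits" name comes from the list above.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def block(w, x):
    # Tag the intermediate that should NOT be stored for the backward pass.
    logits = checkpoint_name(x @ w, "attention_logits")
    return jnp.sum(jnp.tanh(logits))

# Save everything except the intermediate tagged "attention_logits".
policy = jax.checkpoint_policies.save_anything_except_these_names("attention_logits")
block_remat = jax.checkpoint(block, policy=policy)

w = jnp.ones((4, 4)) * 0.1          # toy weights (assumption)
x = jnp.arange(8.0).reshape(2, 4)   # toy input (assumption)

# Gradients agree with the un-rematted function; only memory use differs.
g_plain = jax.grad(block)(w, x)
g_remat = jax.grad(block_remat)(w, x)
```

The same policy object can be passed to nnx.remat(policy=...) in place of jax.checkpoint.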

What changes when you set policy

Forward result: identical to plain remat (and to no remat). The policy only decides which intermediates are kept for the backward pass. This problem therefore checks the forward output only: the test never takes a gradient, so it confirms that passing policy= leaves the forward computation intact rather than measuring any memory effect.

The recipe

import jax
from flax import nnx

@nnx.remat(policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def fwd(model, x):
    return model(x)

Notice we call nnx.remat with parens: nnx.remat(policy=...) returns a decorator, which is then applied to fwd. That's the usual pattern for configurable decorators in Flax NNX.

Worked example: a tiny 2-layer MLP

class MLP(nnx.Module):
    def __init__(self, features, *, rngs):
        self.l1 = nnx.Linear(features, features, rngs=rngs)
        self.l2 = nnx.Linear(features, features, rngs=rngs)
    def __call__(self, x):
        x = jax.nn.relu(self.l1(x))
        x = jax.nn.relu(self.l2(x))
        return x

With dots_with_no_batch_dims_saveable, the outputs of the two x @ W matmuls inside the nnx.Linear layers are saved for the backward pass (neither has batch dimensions), while the relu activations are recomputed from them.

Common pitfalls

  • Picking a policy that stores the large intermediates yet still recomputes the expensive ops. You get the worst of both worlds: a slow backward pass and little memory saved.
  • Forgetting that policy doesn’t affect the forward. If you want to debug “what’s saved”, inspect HLO or use jax.ad_checkpoint.print_saved_residuals rather than expecting the forward output to differ.
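  • Mixing policy= with arbitrary lambdas. The policy must be a callable matching JAX's expected signature (roughly: it receives the primitive, then its abstract inputs and parameters, and returns a bool saying whether the output may be saved). Stick to the named ones in jax.checkpoint_policies unless you really mean to write a custom one.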
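For the "what's saved" question, jax.ad_checkpoint.print_saved_residuals can be pointed at a checkpointed function. The sketch below uses plain JAX with hand-written weights instead of an nnx.Module; the function mlp and the shapes are assumptions, and the printed report is omitted here because its exact format depends on the JAX version.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

def mlp(w1, w2, x):
    # Two matmul + relu layers, mirroring the MLP in this section.
    h = jax.nn.relu(x @ w1)
    return jnp.sum(jax.nn.relu(h @ w2))

policy = jax.checkpoint_policies.dots_with_no_batch_dims_saveable
mlp_remat = jax.checkpoint(mlp, policy=policy)

w1 = jnp.ones((4, 4)) * 0.1   # toy weights (assumption)
w2 = jnp.ones((4, 4)) * 0.2
x = jnp.ones((3, 4))          # toy input (assumption)

# Prints one line per residual that would be stored for the backward pass.
print_saved_residuals(mlp_remat, w1, w2, x)

# Sanity check: gradients are unchanged by the policy.
g_remat = jax.grad(mlp_remat)(w1, w2, x)
g_plain = jax.grad(mlp)(w1, w2, x)
```

Swapping in nothing_saveable or everything_saveable and rerunning gives a direct comparison of what each policy keeps.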

Problem

Implement nnx_remat_policy(seed, x, features):

  1. Define an MLP Module with two nnx.Linear(F, F) layers, each followed by relu.
  2. Build it with int(features) and nnx.Rngs(int(seed)).
  3. Wrap a forward function with @nnx.remat(policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable).
  4. Call fwd(model, x) and return out.reshape(-1).

The output is the same as un-rematted; the test verifies the policy keyword doesn’t break the forward.

Inputs:

  • seed: float (cast to int).
  • x: 1-D (F,).
  • features: float (cast to int) — F.

Output: 1-D (F,) — final activation.
