hard primitives

nn.remat with checkpoint policies

Why this matters

Plain nn.checkpoint(MLP) (pos 82) is all-or-nothing: it rematerializes EVERY intermediate inside the wrapped Module during backward. That saves a lot of memory but also burns a lot of compute — you’re recomputing big matmuls just to get their output activations again.

The next step up is selective rematerialization: tell JAX to save SOME intermediates (those that are expensive to recompute and reasonably cheap to store) and recompute the rest (cheap ops such as element-wise activations). The argument that controls this is policy.

A policy is a callable with the signature (prim, *avals, **params) that JAX consults for each primitive op in the trace. It returns True to mean “save this output across the forward/backward boundary” or False to mean “OK to recompute.” The built-in policies follow this boolean convention.
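To make the callable concrete, here is a minimal sketch of a hand-written policy used with jax.checkpoint. It assumes the policy receives the primitive as its first positional argument and that a plain boolean return value is accepted. The names save_matmuls_only and loss, and all shapes, are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def save_matmuls_only(prim, *_, **__):
    # True  -> keep this primitive's output for the backward pass.
    # False -> allow it to be recomputed during the backward pass.
    return prim.name == "dot_general"

def loss(w1, w2, x):
    h = jnp.tanh(x @ w1)
    return jnp.sum(jnp.tanh(h @ w2))

# Same function, but with the custom policy deciding what is saved.
loss_remat = jax.checkpoint(loss, policy=save_matmuls_only)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (8, 16))
w2 = jax.random.normal(k2, (16, 4))
x = jax.random.normal(k3, (2, 8))

# The policy changes memory/compute behaviour, not the gradient values.
g_plain = jax.grad(loss, argnums=(0, 1))(w1, w2, x)
g_remat = jax.grad(loss_remat, argnums=(0, 1))(w1, w2, x)
```

To see which residuals a given policy actually keeps, jax.ad_checkpoint.print_saved_residuals(loss_remat, w1, w2, x) should print them. The exact listing depends on the JAX version.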

The named policies

JAX ships several pre-baked policies in jax.checkpoint_policies:

  • nothing_saveable: recompute everything. Most aggressive, and what you get when no policy is passed.
  • everything_saveable: save everything. Equivalent to no checkpointing.
  • dots_saveable (alias checkpoint_dots): save the output of every dot_general (matmul). Element-wise ops (relu, the arithmetic inside layernorm) are recomputed, since they are cheap and XLA usually fuses them.
  • dots_with_no_batch_dims_saveable: same idea, but only for matmuls without a batch dimension (typically the QKV and output projections in attention). It doesn’t save the big attention-score matmul, which has batch axes and a huge output.
  • save_anything_except_these_names: a factory that takes names and returns a policy that saves everything except values tagged with those names. Tags are attached with jax.ad_checkpoint.checkpoint_name.
  • save_only_these_names: a factory that takes names and returns a policy that saves only values tagged with those names.
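The two name-based entries work with tagged values rather than primitive types. A minimal sketch, assuming values are tagged with jax.ad_checkpoint.checkpoint_name and that save_only_these_names is called with the names to build the policy. The function layer, the tag "hidden", and the shapes are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(w1, w2, x):
    # Tag the hidden activation so a name-based policy can refer to it.
    h = checkpoint_name(jnp.tanh(x @ w1), "hidden")
    return jnp.sum(jnp.tanh(h @ w2))

# Build a policy that saves only values tagged "hidden".
policy = jax.checkpoint_policies.save_only_these_names("hidden")
layer_remat = jax.checkpoint(layer, policy=policy)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(1), 3)
w1 = jax.random.normal(k1, (8, 16))
w2 = jax.random.normal(k2, (16, 4))
x = jax.random.normal(k3, (2, 8))

# Gradients match the un-checkpointed function.
g_plain = jax.grad(layer, argnums=(0, 1))(w1, w2, x)
g_named = jax.grad(layer_remat, argnums=(0, 1))(w1, w2, x)
```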

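The dots_with_no_batch_dims_saveable policy is the canonical “good default” for transformers: keep the outputs of the QKV and output projections, which are cheap to store but expensive to recompute, and recompute the much larger attention-score matmul. Memory drops substantially, and the extra backward-pass compute is limited to the batched matmuls and element-wise ops.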
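The “batch dim” distinction lives in the dimension_numbers of each dot_general. The sketch below inspects that parameter for a projection-style matmul and an attention-score matmul. The helper has_batch_dims is a hypothetical re-implementation of the kind of check the policy is understood to perform, not the library's own code, and the shapes are made up.

```python
import jax
import jax.numpy as jnp

def first_dot_general(fn, *args):
    """Return the first dot_general equation found in fn's jaxpr."""
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    return next(e for e in jaxpr.eqns if e.primitive.name == "dot_general")

def has_batch_dims(eqn):
    # dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
    (_, _), (lhs_batch, rhs_batch) = eqn.params["dimension_numbers"]
    return bool(lhs_batch) or bool(rhs_batch)

B, H, T, D = 2, 4, 16, 8
x = jnp.ones((B, T, H * D))      # token activations
w_q = jnp.ones((H * D, H * D))   # projection weights
q = jnp.ones((B, H, T, D))
k = jnp.ones((B, H, T, D))

# Projection: contract the feature axis of x with axis 0 of w_q; no batch dims.
proj = lambda a, b: jax.lax.dot_general(a, b, (((2,), (0,)), ((), ())))
# Attention scores: contract head_dim; batch over (batch, heads).
scores = lambda a, b: jax.lax.dot_general(a, b, (((3,), (3,)), ((0, 1), (0, 1))))

proj_eqn = first_dot_general(proj, x, w_q)
scores_eqn = first_dot_general(scores, q, k)

print(has_batch_dims(proj_eqn))    # False
print(has_batch_dims(scores_eqn))  # True
```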

The canonical incantation

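import jax
import flax.linen as nn

# MLP, F, rng and x are assumed to be defined elsewhere.
# Pass the Module CLASS to nn.remat, then construct the wrapped class.
RematMLP = nn.remat(
    MLP,
    policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
)
model = RematMLP(features=F)
params = model.init(rng, x)   # rng: e.g. jax.random.PRNGKey(seed)
out = model.apply(params, x)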

Forward output is identical. Policy only affects what’s saved across the forward/backward boundary.

Forward equivalence again

Same as pos 82: changing the policy doesn’t change forward numerics. So this problem just verifies you wrap correctly with a policy specified.

Why this lands in production

Real projects such as T5X and MaxText use these named policies to dial the memory/compute knob without writing a custom policy function. A common recipe for transformers:

Block = nn.remat(
    TransformerBlock,
    policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
)
Stack = nn.scan(Block, ..., length=num_layers)

Per-block remat with a sensible policy + scan over layers = the JAX-native way to train very deep transformers without OOM.
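The same pattern can be sketched without Flax, as a pure-JAX analogue: jax.checkpoint on a per-layer function and jax.lax.scan over stacked layer weights. This is not the nn.scan call from the recipe above, whose elided arguments are left as they are. The block function, make_stack helper, and sizes are hypothetical.

```python
import jax
import jax.numpy as jnp

def block(h, w):
    # Stand-in for one transformer block: a matmul plus a nonlinearity.
    return jnp.tanh(h @ w)

# Per-block remat with a named policy.
remat_block = jax.checkpoint(
    block,
    policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
)

def make_stack(block_fn):
    def stack(ws, x):
        def step(h, w):
            # scan expects (new_carry, per_step_output).
            return block_fn(h, w), None
        h, _ = jax.lax.scan(step, x, ws)  # ws has a leading layer axis
        return jnp.sum(h)
    return stack

num_layers, d = 6, 8
ws = jax.random.normal(jax.random.PRNGKey(0), (num_layers, d, d)) / jnp.sqrt(d)
x = jnp.ones((4, d))

# Gradients with respect to the stacked weights agree with and without remat.
g_plain = jax.grad(make_stack(block))(ws, x)
g_remat = jax.grad(make_stack(remat_block))(ws, x)
```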

Common pitfalls

  • Wrong policy argument name: it’s policy=, not checkpoint_policy= or save_policy=.
  • Wrapping the instance: pass the CLASS, then construct.
  • Custom policy that always returns False: equivalent to nothing_saveable. Everything is recomputed, which wastes compute if some intermediates are cheap to store and you’d happily save them.

Problem

Implement nn_remat_policy(seed, x, features):

  1. Define MLP(nn.Module) with features field. Layers: Dense → relu → Dense → relu → Dense.

  2. Wrap with nn.remat and policy jax.checkpoint_policies.dots_with_no_batch_dims_saveable:

    RematMLP = nn.remat(MLP, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)

  3. Init with PRNGKey(seed) and x.

  4. Apply.

  5. Return flattened output.

Forward result equals the un-wrapped MLP. The policy only matters under jax.grad.

Inputs:

  • seed: int.
  • x: 1-D (D,).
  • features: int F.

Output: 1-D (F,).

Hints

flax nn-remat checkpoint-policy lifted-transforms
