medium primitives

NNX Checkpoint / Remat

Why this matters

Backprop’s standard recipe (store every intermediate activation during the forward pass, reuse them during the backward pass) has a memory cost proportional to the depth of the network and the size of the intermediates. For a 70B-parameter LM at long context, those activations can dominate the total memory budget, even exceeding what the parameters themselves occupy.

Gradient checkpointing (a.k.a. rematerialization, remat, jax.checkpoint) is the canonical fix. The bargain: during the backward pass, instead of looking up cached activations, recompute them by running the forward of the checkpointed region AGAIN. You save memory; you spend compute. For most large modern models this is a good trade, because modern accelerators have far more compute headroom than memory headroom.
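The bargain is visible in plain JAX before any NNX lifting. The sketch below (the block, weights, and shapes are made up for illustration) wraps one block in jax.checkpoint and checks that the gradients match the un-checkpointed version; only memory use and recompute differ, and neither shows up in the values.

```python
import jax
import jax.numpy as jnp


def block(w, x):
    # Illustrative block: its intermediate h would normally be saved for backward.
    h = jnp.tanh(x @ w)
    return jnp.sin(h) * h


def make_loss(use_checkpoint):
    # With jax.checkpoint, the block's inputs (w, x) are what gets kept;
    # h and the ops after it are recomputed when backward reaches the block.
    f = jax.checkpoint(block) if use_checkpoint else block

    def loss(ws, x):
        for w in ws:
            x = f(w, x)
        return jnp.sum(x ** 2)

    return loss


k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
ws = [w for w in 0.3 * jax.random.normal(k_w, (3, 8, 8))]  # three (8, 8) weights
x = jax.random.normal(k_x, (8,))

g_plain = jax.grad(make_loss(False))(ws, x)
g_remat = jax.grad(make_loss(True))(ws, x)
```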

nnx.remat is the lifted version of jax.checkpoint. It wraps a function that takes nnx Modules; under the hood it does the split/merge so the underlying jax.checkpoint sees pure pytrees.

What checkpointing changes

Forward result: identical. remat doesn’t change what the model computes, only how the backward pass reconstructs intermediates. Calling a rematted forward and an eager forward returns the same array.

Backward result: also identical, up to floating-point rounding noise. Grads through a rematted function are mathematically the same as grads through the un-rematted function; any tiny difference comes from the recomputed ops possibly being compiled or ordered differently.

What changes, under differentiation, is activation memory and FLOP count:

  • Without remat: store activations during the forward, reuse them during the backward. Memory: O(L · activation_size). FLOPs: 1 forward + 1 backward.
  • With remat applied per block (e.g. per layer): store only each block’s input. During the backward, re-run each block’s forward to reconstruct its internal activations. Memory: O(L · block_input_size + one block’s internal activations). FLOPs: 2 forwards + 1 backward, i.e. ~33% more compute if a backward costs about two forwards.

The recipe

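import jax
import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
    def __init__(self, num_layers, features, *, rngs):
        # nnx.List registers the sub-Modules as state (see the note on nnx.List)
        self.layers = nnx.List([
            nnx.Linear(features, features, rngs=rngs)
            for _ in range(num_layers)
        ])

    def __call__(self, x):
        # relu(linear(x)), applied layer by layer
        for layer in self.layers:
            x = jax.nn.relu(layer(x))
        return x

N, F, seed = 4, 16, 0                  # example sizes
model = MLP(N, F, rngs=nnx.Rngs(seed))
x = jnp.ones((F,))

@nnx.remat                             # checkpoint the whole forward
def fwd(model, x):
    return model(x)

out = fwd(model, x)   # same numerics as model(x)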

A note on nnx.List

nnx Modules are pytrees, and the framework needs to know which attributes contain state. A raw Python list of sub-Modules doesn’t tell it that, so wrap the list in nnx.List([...]); this makes the list itself a pytree node whose elements are tracked as state. (Plain lists raise a clear error pointing you to this fix.)

Common pitfalls

  • Expecting the forward output to differ. It doesn’t: remat is a backward-pass optimization. To observe the difference, look at peak memory, the compiled HLO, or the residuals saved for the backward pass.
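  • Wrapping the WHOLE training step in remat. remat only affects code that gets differentiated, so the optimizer update inside a train step gains nothing from it. And a single remat around the entire model makes the backward re-run the full forward and re-materialize all of its intermediates at once, so peak activation memory barely drops while compute still rises. Usually you remat individual blocks (per-layer), not the whole step.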
  • Remat-of-remat. Nested rematerialization re-runs inner forwards once more per level of nesting, so forward counts pile up. Nest deliberately, or pass policy= to control which intermediates are saved rather than recomputed.

Problem

Implement checkpoint_forward(seed, x, features, num_layers):

  1. Build a small MLP class that holds nnx.List([nnx.Linear(F, F, rngs=rngs) for _ in range(N)]) and applies relu(linear(x)) in sequence.

  2. Build the model with int(num_layers) layers of width int(features), initialized from nnx.Rngs(int(seed)).

  3. Wrap a forward function with @nnx.remat:

    @nnx.remat
    def fwd(model, x):
        return model(x)
  4. Compute out = fwd(model, x) and return out.reshape(-1).

The test verifies that the rematted output equals the eager output — remat must not change forward numerics.

Inputs:

  • seed: float (cast to int).
  • x: 1-D (F,).
  • features: float (cast to int) — F.
  • num_layers: float (cast to int) — N.

Output: 1-D (F,) — final activation.

Hints

flax nnx lifted-transforms remat checkpointing
