NNX Checkpoint / Remat
Why this matters
Backprop’s standard recipe (store every intermediate activation during the forward pass, reuse it during the backward pass) has a memory cost that grows with the depth of the network and the size of those intermediates. For a 70B-parameter LM at long context, activations can dominate the total memory budget, leaving little room for the parameters and optimizer state.
Gradient checkpointing (a.k.a. rematerialization, remat, jax.checkpoint) is the canonical fix. The bargain: at backward time, instead of looking up cached activations, recompute them by running the forward AGAIN. You save memory; you spend compute. For most modern models this is a good trade, because modern GPUs are compute-rich and memory-poor.
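A minimal pure-JAX sketch of that bargain, before any NNX is involved. The toy function g and its input are made up for illustration. jax.checkpoint keeps only the function's inputs and recomputes the intermediates during the backward pass, so the value and gradient should match the plain version.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Toy "network": two elementwise stages whose intermediates
    # would normally be saved for the backward pass.
    h = jnp.sin(x)
    return jnp.sum(jnp.sin(h) ** 2)

# Checkpointed version: intermediates such as h are recomputed in backward.
g_remat = jax.checkpoint(g)

x = jnp.linspace(0.0, 1.0, 5)
val_plain, grad_plain = jax.value_and_grad(g)(x)
val_remat, grad_remat = jax.value_and_grad(g_remat)(x)

print(bool(jnp.allclose(val_plain, val_remat)),
      bool(jnp.allclose(grad_plain, grad_remat)))  # prints: True True
```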
nnx.remat is the lifted version of jax.checkpoint. It wraps a function that takes nnx Modules; under the hood it does the split/merge so that the underlying jax.checkpoint sees only pure pytrees.
What checkpointing changes
Forward result: identical. remat doesn’t change what the model computes, only how the backward pass reconstructs intermediates. Calling a rematted forward and an eager forward returns the same array.
Backward result: also identical, up to floating-point rounding noise. Gradients through a rematted function are mathematically the same as gradients through the un-rematted function.
What changes is activation memory and FLOP count under differentiation:
- Without remat: store activations during forward, reuse them during backward. Memory: O(L · activation_size). FLOPs: 1 forward + 1 backward.
- With remat (applied per block): store only each block’s input. During backward, re-run each block’s forward to reconstruct its internal activations just before they are needed. Memory: O(L · block-input size + activations of one block). FLOPs: 2 forwards + 1 backward (i.e., ~33% more compute when a backward costs about two forwards).
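Spelled out, the ~33% figure follows from a common rule of thumb that is an assumption here, not a measurement: a backward pass costs roughly twice a forward pass. Counting cost in units of one forward pass, C_fwd:

```latex
\begin{aligned}
\text{without remat:}\quad & C = C_{\text{fwd}} + C_{\text{bwd}} \approx C_{\text{fwd}} + 2C_{\text{fwd}} = 3C_{\text{fwd}} \\
\text{with remat:}\quad & C = 2C_{\text{fwd}} + C_{\text{bwd}} \approx 2C_{\text{fwd}} + 2C_{\text{fwd}} = 4C_{\text{fwd}} \\
\text{overhead:}\quad & \frac{4C_{\text{fwd}}}{3C_{\text{fwd}}} - 1 \approx 33\%
\end{aligned}
```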
The recipe
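```python
import jax
import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
    def __init__(self, num_layers, features, *, rngs):
        # nnx.List registers each Linear as part of the Module's state.
        self.layers = nnx.List([
            nnx.Linear(features, features, rngs=rngs)
            for _ in range(num_layers)
        ])

    def __call__(self, x):
        for layer in self.layers:
            x = jax.nn.relu(layer(x))
        return x

N, F, seed = 4, 16, 0  # example sizes; any ints work
model = MLP(N, F, rngs=nnx.Rngs(seed))
x = jnp.ones((F,))

@nnx.remat
def fwd(model, x):
    return model(x)

out = fwd(model, x)  # same numerics as model(x)
```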
A note on nnx.List
nnx Modules are pytrees, and pytrees can’t have raw Python lists of sub-Modules: the framework needs to know which attributes contain state. Use nnx.List([...]) to wrap a list of sub-Modules; this makes the list itself a pytree node. (Plain lists raise a clear error pointing you to this fix.)
Common pitfalls
- Expecting the forward output to differ. It doesn’t. remat is a backward-pass optimization. To observe the difference, you’d have to look at peak memory or HLO.
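- Wrapping the WHOLE training step (or the whole loss) in one remat. remat only matters for code that gets differentiated, so the optimizer update gains nothing from it. And a single checkpoint around the entire forward means the backward pass recomputes everything at once, so peak activation memory barely drops while you still pay for the extra forward. Usually you remat individual blocks (per layer), not the whole step.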
- Remat-of-remat. Nested rematerialization can pile up forward counts (1, 2, 4, …). Nest deliberately, or pass policy= to control what gets saved.
Problem
Implement checkpoint_forward(seed, x, features, num_layers):
- Build a small MLP class that holds nnx.List([nnx.Linear(F, F, rngs=rngs) for _ in range(N)]) and applies relu(linear(x)) in sequence.
- Build the model with int(num_layers) layers and int(features) features.
- Wrap a forward function with @nnx.remat:
      @nnx.remat
      def fwd(model, x):
          return model(x)
- Call fwd(model, x) and return out.reshape(-1).
The test verifies that the rematted output equals the eager output — remat must not change forward numerics.
Inputs:
- seed: float (cast to int).
- x: 1-D array of shape (F,).
- features: float (cast to int); this is F.
- num_layers: float (cast to int); this is N.
Output: 1-D array of shape (F,), the final activation.