Gradient Checkpointing: Basics
Why this matters
Every layer in a neural network computes activations during the forward pass. By default, JAX (like PyTorch and TensorFlow) keeps the intermediates the backward pass needs (the residuals) so it can propagate cotangents back through each layer. For a depth-D network this means O(D) activation memory, which is often the bottleneck when training large models on hardware with limited HBM.
jax.checkpoint (alias jax.remat) addresses this with a
recompute-on-demand strategy: intermediate values computed inside the
wrapped function are not saved for the backward pass. Only the
function's inputs are kept, and the intermediates are recomputed from
them during the backward pass. The gradient is the same as in the
un-checkpointed version; only the memory profile changes. The trade-off
is roughly 33% extra compute when every layer is recomputed once, in
exchange for storing only each checkpointed segment's inputs rather than
all of its internal activations.
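The 33% figure can be recovered from a rough cost model. This sketch assumes, as a rule of thumb rather than a measurement, that a backward pass costs about twice a forward pass (each forward matmul needs two matmuls in the backward pass) and that every layer is recomputed exactly once; F denotes the forward cost and is a symbol introduced only for this sketch:

```latex
C_{\text{plain}} \approx \underbrace{F}_{\text{forward}} + \underbrace{2F}_{\text{backward}} = 3F,
\qquad
C_{\text{remat}} \approx F + \underbrace{F}_{\text{recompute}} + 2F = 4F,
\qquad
\frac{C_{\text{remat}}}{C_{\text{plain}}} \approx \frac{4}{3}
```

Checkpointing only some segments recomputes only those segments, so the overhead shrinks proportionally.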
Gradient checkpointing appears wherever large models are trained:
- Transformer training: checkpoint every residual block to train longer sequences.
- ResNet / EfficientNet: checkpoint skip-connection blocks.
- Recurrent networks: checkpoint the per-step body (see scan + checkpoint).
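For the recurrent case, a minimal sketch of checkpointing a scan body might look like the following. The cell `rnn_step` and the loss `rnn_loss` are hypothetical names for illustration; the only library calls are `jax.checkpoint`, `jax.lax.scan`, and `jax.grad`. The gradient check is there to show that checkpointing changes what is stored, not what is computed.

```python
import jax
import jax.numpy as jnp


def rnn_step(w, h, x):
    # Hypothetical tanh RNN cell, used only for illustration.
    return jnp.tanh(h @ w + x)


def rnn_loss(w, h0, xs, remat):
    def body(h, x):
        # scan body signature: (carry, x) -> (new_carry, per-step output)
        return rnn_step(w, h, x), None

    if remat:
        # Checkpoint the per-step body: each step keeps only its inputs,
        # and the step's intermediates are recomputed on the backward pass.
        body = jax.checkpoint(body)
    h_final, _ = jax.lax.scan(body, h0, xs)
    return jnp.sum(h_final ** 2)


w = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (4, 4))
h0 = jnp.zeros((4,))
xs = jnp.ones((10, 4))  # 10 time steps, 4 features each

# `remat` is a plain Python bool; jax.grad only differentiates argument 0.
g_plain = jax.grad(rnn_loss)(w, h0, xs, False)
g_remat = jax.grad(rnn_loss)(w, h0, xs, True)
print(jnp.allclose(g_plain, g_remat))  # True
```

Wrapping the body (rather than the whole `scan`) keeps the checkpoint boundary at one time step, which is the granularity that actually bounds per-step memory.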
Worked mini-example
```python
import jax
import jax.numpy as jnp

# Without checkpoint: under differentiation, residuals of `middle` are retained in memory.
def middle(y, w):
    return y @ w

x = jnp.ones((1, 4))
w1 = jnp.eye(4)
w2 = jnp.eye(4)
w3 = jnp.eye(4)

y1 = jax.nn.relu(x @ w1)
y2 = jax.nn.relu(middle(y1, w2))  # residuals stored for the backward pass
y3 = y2 @ w3

# With checkpoint: residuals of `middle` are dropped and recomputed.
y2_ckpt = jax.nn.relu(jax.checkpoint(middle)(y1, w2))  # remat happens on the backward pass

# Forward values are identical.
assert jnp.allclose(y2, y2_ckpt)
```
Common pitfalls
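- jax.remat is an alias: jax.remat and jax.checkpoint are the same function. Both names are correct; prefer jax.checkpoint for clarity.
- The wrapped function must be pure: like every JAX transformation, jax.checkpoint traces the function, and the backward-pass recomputation replays the traced computation rather than re-running your Python code. Python side effects (print, counters, list appends) therefore run only at trace time and are not repeated during recomputation, and impure code can break tracing.
- Wrapping the whole pipeline defeats granularity: the benefit comes from selective checkpointing at sub-network boundaries. A single checkpoint around everything just recomputes the entire forward pass during the backward pass, with all of its intermediates live again at once, so total compute goes up while peak memory barely improves.
- The forward result is unchanged: the loss and gradients are numerically the same; only which intermediates are stored versus recomputed differs.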
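To illustrate the granularity point, here is a minimal sketch that places one checkpoint boundary per block instead of one around the whole loss. `residual_block`, `ckpt_block`, and `deep_loss` are hypothetical names chosen for this example, and the block itself is a toy residual layer, not any particular architecture.

```python
import jax
import jax.numpy as jnp


def residual_block(h, w):
    # Hypothetical residual block: h + relu(h @ w).
    return h + jax.nn.relu(h @ w)


# Checkpoint at the block boundary: by default each block saves only its
# inputs (h, w) and recomputes relu(h @ w) during the backward pass.
ckpt_block = jax.checkpoint(residual_block)


def deep_loss(ws, x, block):
    h = x
    for w in ws:  # one checkpoint boundary per block when block=ckpt_block
        h = block(h, w)
    return jnp.sum(h ** 2)


keys = jax.random.split(jax.random.PRNGKey(0), 4)
ws = [0.1 * jax.random.normal(k, (8, 8)) for k in keys]  # 4 blocks
x = jnp.ones((2, 8))

# Gradients w.r.t. the list of weights; `block` is passed through undifferentiated.
g_ckpt = jax.grad(deep_loss)(ws, x, ckpt_block)
g_plain = jax.grad(deep_loss)(ws, x, residual_block)
```

During the backward pass only one block's intermediates need to be rematerialized at a time, which is where the memory saving comes from; `jax.checkpoint(deep_loss)` would instead rematerialize all four blocks together.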
Problem
Implement checkpointed_layer_loss(w1, w2, w3, x) that applies three
dense layers (x @ w1, y1 @ w2, y2 @ w3), with a ReLU after each of the
first two matmuls. The middle matmul (y1 @ w2) must be
wrapped in jax.checkpoint. Return the sum of squared outputs.
- w1, w2, w3: 2-D jax arrays of shape (d, d).
- x: 2-D jax array of shape (N, d).
Returns: scalar, sum(y3 ** 2).
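One possible sketch of the function, with a non-checkpointed `reference_loss` (a hypothetical helper, as is `_middle`) used to check that values and gradients agree. In this toy network the memory saving is nominal, since a single matmul has no intermediates besides its inputs and output; the point is where the checkpoint boundary sits.

```python
import jax
import jax.numpy as jnp


def _middle(y, w):
    # The middle dense layer's matmul (helper name chosen for this sketch).
    return y @ w


def checkpointed_layer_loss(w1, w2, w3, x):
    y1 = jax.nn.relu(x @ w1)
    # Only the middle matmul is wrapped in jax.checkpoint.
    y2 = jax.nn.relu(jax.checkpoint(_middle)(y1, w2))
    y3 = y2 @ w3
    return jnp.sum(y3 ** 2)


def reference_loss(w1, w2, w3, x):
    # Same network without checkpointing, for comparison.
    y1 = jax.nn.relu(x @ w1)
    y2 = jax.nn.relu(y1 @ w2)
    return jnp.sum((y2 @ w3) ** 2)


eye = jnp.eye(4)
print(checkpointed_layer_loss(eye, eye, eye, jnp.ones((3, 4))))  # 12.0

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1, w2, w3 = (jax.random.normal(k, (4, 4)) for k in (k1, k2, k3))
x = jnp.ones((3, 4))
g = jax.grad(checkpointed_layer_loss, argnums=(0, 1, 2))(w1, w2, w3, x)
g_ref = jax.grad(reference_loss, argnums=(0, 1, 2))(w1, w2, w3, x)
```

With identity weights and an all-ones input, every layer passes ones through unchanged, so the loss is 3 × 4 = 12.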