
Gradient Checkpointing: Basics

Why this matters

Every layer in a neural network computes activations during the forward pass. By default, JAX (like PyTorch/TensorFlow) retains all of them so the backward pass can multiply cotangents through them. For a depth-D network this means O(D) activation memory, which is the bottleneck when training large models on hardware with limited HBM.

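jax.checkpoint (alias jax.remat) addresses this with a recompute-on-demand strategy: the intermediate activations inside the wrapped function are not saved for the backward pass. Only the function's inputs are kept, and the intermediates are recomputed when the backward pass needs them. The gradient is mathematically the same as in the un-checkpointed version (in practice equal up to floating-point rounding); only the memory profile and compute cost change. The trade-off: roughly one extra forward pass through the checkpointed segments, often quoted as about 33% extra compute when the backward pass costs about twice the forward pass, in exchange for not storing the internal activations of those segments.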
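As a minimal sketch of the claim that only the memory profile changes, the snippet below compares gradients with and without jax.checkpoint. The two-step `block` function, the loss, and the input values are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def block(x, w):
    # Hypothetical two-step block; `h` is the intermediate that
    # jax.checkpoint recomputes instead of saving.
    h = jnp.tanh(x @ w)
    return jnp.tanh(h @ w)

def loss_plain(w, x):
    return jnp.sum(block(x, w) ** 2)

def loss_ckpt(w, x):
    # Same computation, but the block is wrapped in jax.checkpoint.
    return jnp.sum(jax.checkpoint(block)(x, w) ** 2)

x = jnp.linspace(-1.0, 1.0, 12).reshape(3, 4)
w = 0.5 * jnp.eye(4) + 0.1

g_plain = jax.grad(loss_plain)(w, x)
g_ckpt = jax.grad(loss_ckpt)(w, x)

# The gradients agree up to floating-point tolerance.
assert jnp.allclose(g_plain, g_ckpt, atol=1e-6)
```

The comparison uses a tolerance rather than exact equality, since the compiler is free to schedule the recomputed operations differently.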

Gradient checkpointing appears everywhere large models are trained:

  • Transformer training: checkpoint every residual block to train longer sequences.
  • ResNet / EfficientNet: checkpoint skip-connection blocks.
  • Recurrent networks: checkpoint the per-step body (see scan + checkpoint).
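For the recurrent case, one way this might look is a jax.lax.scan whose per-step body is wrapped in jax.checkpoint. This is a sketch under assumptions: the `step` update rule, the `make_rnn_loss` helper, and all shapes are hypothetical, not taken from a specific model.

```python
import jax
import jax.numpy as jnp

def step(h, x_t, w):
    # One hypothetical recurrent update; stands in for any per-step body.
    return jnp.tanh(h @ w + x_t)

def make_rnn_loss(use_checkpoint):
    # Choose the plain or checkpointed step once, outside of tracing.
    body = jax.checkpoint(step) if use_checkpoint else step

    def rnn_loss(w, xs, h0):
        def scan_body(h, x_t):
            # Carry the hidden state; emit no per-step outputs.
            return body(h, x_t, w), None
        h_final, _ = jax.lax.scan(scan_body, h0, xs)
        return jnp.sum(h_final ** 2)

    return rnn_loss

T, N, d = 5, 2, 4
xs = jnp.linspace(-1.0, 1.0, T * N * d).reshape(T, N, d)
h0 = jnp.zeros((N, d))
w = 0.3 * jnp.eye(d) + 0.05

g_plain = jax.grad(make_rnn_loss(False))(w, xs, h0)
g_ckpt = jax.grad(make_rnn_loss(True))(w, xs, h0)
assert jnp.allclose(g_plain, g_ckpt, atol=1e-6)
```

With the checkpointed body, the backward pass re-runs each step from its saved inputs instead of keeping the step's internal intermediates for all T steps.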

Worked mini-example

import jax
import jax.numpy as jnp

# Without checkpoint: when differentiated, the residuals of `middle` are kept for the backward pass.
def middle(y, w):
    return y @ w

x  = jnp.ones((1, 4))
w1 = jnp.eye(4)
w2 = jnp.eye(4)
w3 = jnp.eye(4)

y1 = jax.nn.relu(x @ w1)
y2 = jax.nn.relu(middle(y1, w2))       # activations stored
y3 = y2 @ w3

# With checkpoint: when differentiated, the intermediates of `middle` are dropped and recomputed.
y2_ckpt = jax.nn.relu(jax.checkpoint(middle)(y1, w2))  # remat on bwd

# Forward values are identical.
assert jnp.allclose(y2, y2_ckpt)

Common pitfalls

  • jax.remat is an alias: jax.remat and jax.checkpoint are the same function. Both names work; prefer jax.checkpoint for clarity.
  • Wrapped function MUST be pure: jax.checkpoint traces the function and re-executes the traced computation during the backward pass. Python side effects (print, counter, list-append) run at trace time rather than each time the computation runs, so they do not reflect the actual execution count, and values stored from inside the function can leak tracers.
  • Wrapping the whole pipeline defeats granularity: the benefit comes from selective checkpointing at sub-network boundaries. A single checkpoint around everything recomputes the full forward pass during the backward pass, at which point all intermediates are live again, so it typically adds recompute without lowering peak memory.
  • Forward result is unchanged: the forward values are identical and the gradients match the un-checkpointed ones up to floating-point rounding; only memory use and compute cost differ.
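The granularity point can be sketched by comparing three variants of the same stack: no checkpointing, one checkpoint per block, and one checkpoint around the whole loss. The `block` and `stack_loss` functions and the weights are hypothetical. The snippet only verifies that the gradients agree; it does not measure memory.

```python
import functools
import jax
import jax.numpy as jnp

def block(x, w):
    # Hypothetical sub-network; a natural checkpoint boundary.
    return jnp.tanh(x @ w)

def stack_loss(ws, x, per_block=False):
    # per_block is a plain Python flag, fixed before differentiation.
    f = jax.checkpoint(block) if per_block else block
    for w in ws:
        x = f(x, w)
    return jnp.sum(x ** 2)

ws = [0.5 * jnp.eye(4) + 0.1 * i for i in range(3)]
x = jnp.linspace(-1.0, 1.0, 8).reshape(2, 4)

grads_plain = jax.grad(stack_loss)(ws, x)
grads_per_block = jax.grad(functools.partial(stack_loss, per_block=True))(ws, x)
# One checkpoint around everything: valid, but all intermediates are
# recomputed together during the backward pass.
grads_whole = jax.grad(jax.checkpoint(stack_loss))(ws, x)

for a, b, c in zip(grads_plain, grads_per_block, grads_whole):
    assert jnp.allclose(a, b, atol=1e-6)
    assert jnp.allclose(a, c, atol=1e-6)
```

All three variants compute the same gradients, so the choice between them is purely a memory and compute decision.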

Problem

Implement checkpointed_layer_loss(w1, w2, w3, x) that applies three dense layers (x @ w1, y1 @ w2, y2 @ w3) with ReLU after each of the first two matmuls and no activation after the third. The middle matmul (y1 @ w2) must be wrapped in jax.checkpoint. Return the sum of squared outputs.

  • w1, w2, w3: 2-D jax arrays (d, d).
  • x: 2-D jax array (N, d).

Returns: scalar, sum(y3 ** 2).
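One possible reading of this specification is sketched below. It is not the site's reference solution, and it assumes that only the matmul (not the ReLU that follows it) goes inside jax.checkpoint, as the statement describes. The example inputs are hypothetical.

```python
import jax
import jax.numpy as jnp

def checkpointed_layer_loss(w1, w2, w3, x):
    # Layer 1: matmul followed by ReLU.
    y1 = jax.nn.relu(x @ w1)
    # Layer 2: only the matmul is wrapped in jax.checkpoint.
    middle = jax.checkpoint(lambda y, w: y @ w)
    y2 = jax.nn.relu(middle(y1, w2))
    # Layer 3: matmul with no activation.
    y3 = y2 @ w3
    return jnp.sum(y3 ** 2)

# Identity weights and an all-ones input pass x through unchanged,
# so the loss is the number of elements: 2 * 4 = 8.
x = jnp.ones((2, 4))
w = jnp.eye(4)
loss = checkpointed_layer_loss(w, w, w, x)
assert jnp.allclose(loss, 8.0)

# The function remains differentiable with respect to every weight.
g1, g2, g3 = jax.grad(checkpointed_layer_loss, argnums=(0, 1, 2))(w, w, w, x)
```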

Hints

jax checkpoint remat
