Checkpoint with Save Policy
Why this matters
The default jax.checkpoint is a blunt instrument: it recomputes every
operation inside the wrapped function during the backward pass. For a
function that mixes cheap elementwise ops (ReLU, sigmoid) with expensive
matrix multiplications, this is wasteful. If a matmul accounts for, say, 95% of the
forward time, recomputing it costs nearly a full extra forward pass, while storing
its output costs only memory.
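The policy= argument gives fine-grained control. A policy is a callable
(prim, *avals, **params) -> bool. JAX passes it a primitive (for example
dot_general), the abstract values of that primitive's inputs, and the
primitive's parameters. It returns True when the output should be saved for the
backward pass and False when it should be recomputed. JAX ships several standard
policies under jax.checkpoint_policies.*: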
| Policy | What it saves |
|---|---|
| everything_saveable | All intermediates (no recompute) |
| nothing_saveable | Nothing (full recompute; same as the default) |
| dots_with_no_batch_dims_saveable | Matmul (dot_general) outputs without batch dimensions |
| dots_saveable | All matmul (dot_general) outputs, including batched ones |
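The difference between the last two rows can be made concrete. The sketch below traces a plain 2-D matmul and a batched one with jax.lax.dot_general, then calls each built-in policy directly on the resulting dot_general equation. It assumes, as JAX's own policies are written, that a policy is invoked as policy(primitive, *input_avals, **params). The helpers first_dot_general and ask are hypothetical names used only for this illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

policies = jax.checkpoint_policies


def first_dot_general(fn, *args):
    """Hypothetical helper: trace fn and return its first dot_general equation."""
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    return next(e for e in jaxpr.eqns if e.primitive.name == "dot_general")


def ask(policy, eqn):
    """Hypothetical helper: call a policy as (primitive, *input_avals, **params)."""
    avals = [v.aval for v in eqn.invars]
    return bool(policy(eqn.primitive, *avals, **eqn.params))


# Plain 2-D matmul: contract lhs axis 1 with rhs axis 0, no batch dimensions.
plain = first_dot_general(
    lambda a, b: lax.dot_general(a, b, (((1,), (0,)), ((), ()))),
    jnp.ones((4, 3)), jnp.ones((3, 5)))

# Batched matmul: axis 0 of both operands is a batch dimension.
batched = first_dot_general(
    lambda a, b: lax.dot_general(a, b, (((2,), (1,)), ((0,), (0,)))),
    jnp.ones((2, 4, 3)), jnp.ones((2, 3, 5)))

for name in ["dots_with_no_batch_dims_saveable", "dots_saveable"]:
    policy = getattr(policies, name)
    print(f"{name}: plain={ask(policy, plain)}, batched={ask(policy, batched)}")
```

Consistent with the table, dots_with_no_batch_dims_saveable should answer True only for the plain matmul, while dots_saveable answers True for both.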
dots_with_no_batch_dims_saveable is a common choice for MLP-like networks:
matmul outputs (expensive to recompute) are kept, and elementwise activations
(cheap to recompute) are recomputed. Compared with no checkpointing, memory drops
because the activations are not stored; compared with the default, the backward
pass does not rerun the matmuls.
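Because a policy is just a callable, you can also write your own. The sketch below defines a hypothetical save_matmuls policy that reproduces what dots_saveable already does (mark every dot_general output as saveable). It exists only to illustrate the calling convention and assumes that the matmul primitive's .name is "dot_general".

```python
import jax
import jax.numpy as jnp


def save_matmuls(prim, *avals, **params):
    """Hypothetical policy: mark every dot_general output as saveable.

    prim   -- the JAX primitive being considered (e.g. dot_general)
    avals  -- abstract values (shape/dtype) of the primitive's inputs
    params -- the primitive's parameters (e.g. dimension_numbers)
    Returning True saves the output; False means it is recomputed.
    """
    return prim.name == "dot_general"


def f(w, x):
    y = x @ w             # dot_general: saved by save_matmuls
    y = jax.nn.relu(y)    # elementwise: recomputed in the backward pass
    return jnp.sum(y)


f_custom = jax.checkpoint(f, policy=save_matmuls)

w = jnp.linspace(-1.0, 1.0, 16).reshape(4, 4)
x = jnp.linspace(-2.0, 2.0, 12).reshape(3, 4)

g_plain = jax.grad(f)(w, x)          # no checkpointing at all
g_custom = jax.grad(f_custom)(w, x)  # checkpointed with the custom policy
print(jnp.allclose(g_plain, g_custom))
```

In practice you would reach for the built-in policy. A custom callable becomes useful when the decision depends on something the built-ins do not inspect, such as the primitive's params.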
Worked mini-example
```python
import jax
import jax.numpy as jnp

def f(w, x):
    y = x @ w           # matmul: expensive, SAVE
    y = jax.nn.relu(y)  # elementwise: cheap, RECOMPUTE
    return jnp.sum(y)

# Default: recompute everything
f_default = jax.checkpoint(f)

# Policy: save matmul outputs only
policy = jax.checkpoint_policies.dots_with_no_batch_dims_saveable
f_policy = jax.checkpoint(f, policy=policy)

x = jnp.ones((4, 4))
w = jnp.eye(4)

# Both produce the same gradient.
g1 = jax.grad(f_default, argnums=0)(w, x)
g2 = jax.grad(f_policy, argnums=0)(w, x)
assert jnp.allclose(g1, g2)  # True
```
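The assertion in the mini-example only confirms that the gradients agree. To see what the policy actually changes, jax.ad_checkpoint.print_saved_residuals reports which values are kept for the backward pass. The sketch below captures that report for both versions of f. The helper saved_report is a hypothetical name, and because the exact wording of each report line may differ between JAX versions, the check only looks for whether a dot_general output appears.

```python
import contextlib
import io

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals


def f(w, x):
    y = x @ w           # 2-D matmul: a dot_general with no batch dims
    y = jax.nn.relu(y)  # cheap elementwise op
    return jnp.sum(y)


def saved_report(fn, *args):
    """Hypothetical helper: capture print_saved_residuals output as a string."""
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        print_saved_residuals(fn, *args)
    return buf.getvalue()


x = jnp.ones((4, 4))
w = jnp.eye(4)

policy = jax.checkpoint_policies.dots_with_no_batch_dims_saveable
report_default = saved_report(jax.checkpoint(f), w, x)
report_policy = saved_report(jax.checkpoint(f, policy=policy), w, x)

print("default:\n" + report_default)
print("dots_with_no_batch_dims_saveable:\n" + report_policy)
```

With the default, only the function's arguments should be listed. With the policy, the matmul output should also appear as a saved residual.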
Common pitfalls
- Policies live under jax.checkpoint_policies.*: they are not strings or ints but callables taken from that module.
- The result is identical regardless of policy. The policy only decides which intermediates are stored and which are recomputed (a memory versus compute trade-off), not the mathematical output, so a test that checks only the numerical result cannot tell policies apart.
- nothing_saveable matches the default: calling jax.checkpoint(f) with no policy recomputes everything.
- A policy does not change the function signature: f_remat = jax.checkpoint(f, policy=...) takes the same arguments as f, so f_remat(w1, w2, x) is called exactly like f(w1, w2, x).
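The points about identical results and unchanged signatures can be checked directly. The sketch below wraps the mini-example's f with each of the four built-in policies, calls each wrapped function with the same arguments as f, and compares every gradient against an un-checkpointed jax.grad(f). The input values are arbitrary and chosen only so that some ReLU inputs are negative.

```python
import jax
import jax.numpy as jnp

policies = jax.checkpoint_policies


def f(w, x):
    y = x @ w
    y = jax.nn.relu(y)
    return jnp.sum(y)


w = jnp.linspace(-1.0, 1.0, 16).reshape(4, 4)
x = jnp.linspace(-2.0, 2.0, 12).reshape(3, 4)

reference = jax.grad(f)(w, x)  # no checkpointing at all

grads = {}
for name in ["nothing_saveable", "everything_saveable",
             "dots_saveable", "dots_with_no_batch_dims_saveable"]:
    f_remat = jax.checkpoint(f, policy=getattr(policies, name))
    grads[name] = jax.grad(f_remat)(w, x)  # same call signature as f

for name, g in grads.items():
    print(name, bool(jnp.allclose(g, reference)))
```

Every line should print True. This is why a purely numerical test cannot verify that the intended policy was used.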
Problem
Implement checkpoint_dot_policy(w1, w2, x), which applies the pipeline:
y = x @ w1 → relu(y) → y @ w2 → y ** 2 → sum
Wrap the entire pipeline in jax.checkpoint using
policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable.
- w1, w2: 2-D JAX arrays of shape (d, d).
- x: 2-D JAX array of shape (N, d).
Returns: a scalar equal to sum((relu(x @ w1) @ w2) ** 2).