
Checkpoint with Save Policy

Why this matters

The default jax.checkpoint is a blunt instrument: it saves only the wrapped function's inputs and recomputes every intermediate inside it during the backward pass. For a function that mixes cheap elementwise ops (ReLU, sigmoid) with expensive matrix multiplications, this is wasteful: a matmul that dominates the forward-pass cost should not be recomputed if its output can be kept in memory instead.
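One way to see what the default jax.checkpoint keeps is print_saved_residuals from jax.ad_checkpoint, which prints the values stored for the backward pass. The sketch below compares a small illustrative function (layer, with made-up shapes) with and without checkpointing. The exact printed lines vary across JAX versions, so the comments only say what to look for.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

def layer(w, x):
    # Illustrative function: one matmul followed by a cheap elementwise op.
    return jnp.sum(jnp.sin(x @ w))

w = jnp.full((3, 3), 0.1)
x = jnp.arange(6.0).reshape(2, 3)

print("without jax.checkpoint:")
print_saved_residuals(layer, w, x)  # expect stored intermediates such as cos(x @ w)

print("with jax.checkpoint (default policy):")
print_saved_residuals(jax.checkpoint(layer), w, x)  # expect only the arguments w and x

# Recomputation changes what is stored, not the gradient.
g_plain = jax.grad(layer)(w, x)
g_remat = jax.grad(jax.checkpoint(layer))(w, x)
print(jnp.allclose(g_plain, g_remat))
```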

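The policy= argument gives fine-grained control. A policy is a callable with signature (prim, *avals, **params) -> bool: when JAX differentiates the checkpointed function, it calls the policy for each primitive operation inside it, passing the primitive, the abstract values (shapes and dtypes) of its inputs, and its parameters. The policy returns True when that operation's output should be saved for the backward pass and False when it should be recomputed. JAX ships several standard policies under jax.checkpoint_policies.*: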

  • everything_saveable: saves all intermediates (no recomputation).
  • nothing_saveable: saves nothing (full recomputation; same behavior as the default).
  • dots_with_no_batch_dims_saveable: saves only matmul (dot_general) outputs that have no batch dimensions.
  • dots_saveable: saves all matmul outputs, including batched ones.
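Because a policy is just a callable, its decisions can be inspected directly. This sketch traces an unbatched and a batched dot_general with jax.make_jaxpr and asks each built-in policy about the resulting equation, using the (prim, *avals, **params) calling convention described above. The helpers plain_matmul, batched_matmul, and decisions are hypothetical names for illustration.

```python
import jax
import jax.numpy as jnp

policies = {
    "everything_saveable": jax.checkpoint_policies.everything_saveable,
    "nothing_saveable": jax.checkpoint_policies.nothing_saveable,
    "dots_with_no_batch_dims_saveable":
        jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
    "dots_saveable": jax.checkpoint_policies.dots_saveable,
}

def plain_matmul(a, b):
    # (n, k) @ (k, m): contract a's dim 1 with b's dim 0, no batch dimensions.
    return jax.lax.dot_general(a, b, (((1,), (0,)), ((), ())))

def batched_matmul(a, b):
    # (B, n, k) @ (B, k, m): dimension 0 is a batch dimension on both sides.
    return jax.lax.dot_general(a, b, (((2,), (1,)), ((0,), (0,))))

def decisions(fn, *args):
    """Ask every policy whether the dot_general in fn's jaxpr would be saved."""
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    eqn = next(e for e in jaxpr.eqns if e.primitive.name == "dot_general")
    avals = [v.aval for v in eqn.invars]
    return {name: bool(p(eqn.primitive, *avals, **eqn.params))
            for name, p in policies.items()}

unbatched = decisions(plain_matmul, jnp.ones((4, 3)), jnp.ones((3, 5)))
batched = decisions(batched_matmul, jnp.ones((2, 4, 3)), jnp.ones((2, 3, 5)))
for name in policies:
    print(name, "unbatched:", unbatched[name], "batched:", batched[name])
```

Only the batched matmul separates the two dots policies.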

dots_with_no_batch_dims_saveable is a common choice for MLP-like networks: matmul outputs (expensive to recompute) are kept, while elementwise activations (cheap to recompute) are recomputed. Compared with saving everything, memory use drops because activations are not stored; compared with the default, the backward pass avoids redoing the matmuls.
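A common pattern is to checkpoint each layer separately. The following is a minimal sketch, assuming a stack of identical GELU layers; mlp_layer, make_loss, and the sizes are hypothetical. The gradients should match the un-checkpointed version, since the policy only changes what is stored.

```python
import jax
import jax.numpy as jnp

policy = jax.checkpoint_policies.dots_with_no_batch_dims_saveable

def mlp_layer(w, h):
    # Matmul output: saved under the policy. GELU: recomputed in the backward pass.
    return jax.nn.gelu(h @ w)

def make_loss(layer_fn):
    """Hypothetical helper: builds an MLP loss that applies layer_fn once per layer."""
    def loss(ws, x):
        h = x
        for w in ws:
            h = layer_fn(w, h)
        return jnp.mean(h ** 2)
    return loss

loss_plain = make_loss(mlp_layer)
loss_remat = make_loss(jax.checkpoint(mlp_layer, policy=policy))  # one checkpoint per layer

keys = jax.random.split(jax.random.PRNGKey(0), 4)
ws = [jax.random.normal(k, (8, 8)) / jnp.sqrt(8.0) for k in keys[:3]]
x = jax.random.normal(keys[3], (16, 8))

g_plain = jax.grad(loss_plain)(ws, x)
g_remat = jax.grad(loss_remat)(ws, x)
print(all(jnp.allclose(a, b, atol=1e-6) for a, b in zip(g_plain, g_remat)))
```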

Worked mini-example

import jax
import jax.numpy as jnp

def f(w, x):
    y = x @ w            # matmul: expensive, SAVE
    y = jax.nn.relu(y)   # elementwise: cheap, RECOMPUTE
    return jnp.sum(y)

# Default: recompute everything
f_default = jax.checkpoint(f)

# Policy: save matmul outputs only
policy = jax.checkpoint_policies.dots_with_no_batch_dims_saveable
f_policy = jax.checkpoint(f, policy=policy)

x = jnp.ones((4, 4))
w = jnp.eye(4)

# Both produce the same gradient.
g1 = jax.grad(f_default, argnums=0)(w, x)
g2 = jax.grad(f_policy, argnums=0)(w, x)
assert jnp.allclose(g1, g2)   # True
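The gradient check alone cannot show what the policy changes. print_saved_residuals from jax.ad_checkpoint can: with the default policy only the arguments should be listed as residuals, while with dots_with_no_batch_dims_saveable the output of x @ w should appear as well. The exact output format depends on the JAX version, so the comments below are expectations rather than verbatim output.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

def f(w, x):
    y = x @ w
    y = jax.nn.relu(y)
    return jnp.sum(y)

f_default = jax.checkpoint(f)
f_policy = jax.checkpoint(
    f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)

x = jnp.ones((4, 4))
w = jnp.eye(4)

print("default policy:")
print_saved_residuals(f_default, w, x)  # expect: only the arguments w and x
print("dots_with_no_batch_dims_saveable:")
print_saved_residuals(f_policy, w, x)   # expect: also the dot_general output
```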

Common pitfalls

  • Policies live under jax.checkpoint_policies.*: they are not strings or ints but callables, accessed as attributes of that module.
  • The result does not depend on the policy: the policy controls which intermediates are stored versus recomputed (a memory/compute trade-off), not the mathematical output. A test that checks only numeric results is equally valid for every policy, and for the same reason cannot tell policies apart.
  • nothing_saveable is the default behavior: calling jax.checkpoint(f) with no policy recomputes everything.
  • The policy does NOT change the function signature: f_remat = jax.checkpoint(f, policy=...) is called as f_remat(w1, w2, x), with the same arguments as f(w1, w2, x).
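Since a policy is an ordinary callable, a custom one can wrap a built-in policy. The sketch below uses a hypothetical logging_dots_policy that records which primitives JAX asks about, and applies it with the functools.partial decorator idiom; the decorated f_remat keeps the (w1, w2, x) signature. In this sketch the log fills in during jax.grad, and the exact list of primitives logged depends on the JAX version.

```python
import functools
import jax
import jax.numpy as jnp

calls = []  # (primitive name, decision) pairs seen by the policy

def logging_dots_policy(prim, *avals, **params):
    """Hypothetical custom policy: defers to a built-in one and logs each query."""
    decision = jax.checkpoint_policies.dots_with_no_batch_dims_saveable(
        prim, *avals, **params)
    calls.append((prim.name, bool(decision)))
    return decision

@functools.partial(jax.checkpoint, policy=logging_dots_policy)
def f_remat(w1, w2, x):
    # Same call signature as the undecorated function would have.
    h = jax.nn.relu(x @ w1)
    return jnp.sum(h @ w2)

w1, w2, x = jnp.eye(3), jnp.ones((3, 3)), jnp.ones((2, 3))
print(f_remat(w1, w2, x))  # 18.0
grads = jax.grad(f_remat, argnums=(0, 1))(w1, w2, x)
for name, saved in calls:
    print(f"{name:<14} saved={saved}")
```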

Problem

Implement checkpoint_dot_policy(w1, w2, x), which applies the following pipeline (matmul, ReLU, matmul, square, sum):

y = x @ w1  →  relu(y)  →  y @ w2  →  y^2  →  sum

Wrap the entire pipeline in jax.checkpoint using policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable.

  • w1, w2: 2-D jax arrays (d, d).
  • x: 2-D jax array (N, d).

Returns: a scalar, sum((relu(x @ w1) @ w2) ** 2).
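Working out the backward pass by hand shows which intermediates the policy matters for. With y = x w1, h = relu(y), z = h w2, and L the returned scalar:

```latex
\begin{aligned}
y &= x\, w_1, \qquad h = \operatorname{relu}(y), \qquad z = h\, w_2, \qquad L = \textstyle\sum_{ij} z_{ij}^2 \\
\frac{\partial L}{\partial z} &= 2z \\
\frac{\partial L}{\partial w_2} &= h^\top (2z) \\
\frac{\partial L}{\partial y} &= \big((2z)\, w_2^\top\big) \odot \mathbb{1}[y > 0] \\
\frac{\partial L}{\partial w_1} &= x^\top \Big(\big((2z)\, w_2^\top\big) \odot \mathbb{1}[y > 0]\Big)
\end{aligned}
```

Besides the inputs, the backward pass needs y (for the ReLU mask), h (for the w2 gradient), and z. Under dots_with_no_batch_dims_saveable the matmul outputs y and z should be stored, while h = relu(y) and the mask are cheap to recompute from the saved y.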

Hints

jax checkpoint policy
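One possible implementation, sketched from the problem statement: the whole pipeline goes inside a single jax.checkpoint call with the dots policy. This is a hedged sketch, not the site's reference solution; the inner helper name pipeline is hypothetical.

```python
import jax
import jax.numpy as jnp

def checkpoint_dot_policy(w1, w2, x):
    """Sketch: sum((relu(x @ w1) @ w2) ** 2) under one checkpoint with a dots policy."""
    def pipeline(w1, w2, x):
        y = x @ w1              # (N, d) @ (d, d), no batch dims: saved
        y = jax.nn.relu(y)      # elementwise: recomputed in the backward pass
        y = y @ w2              # second matmul, no batch dims: saved
        return jnp.sum(y ** 2)  # elementwise square is recomputed; sum is the output

    remat = jax.checkpoint(
        pipeline, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
    return remat(w1, w2, x)

print(checkpoint_dot_policy(jnp.eye(3), jnp.eye(3), jnp.ones((2, 3))))  # 6.0
```

Checking the sketch against an un-checkpointed reference on random inputs, for both the value and the gradients, is a quick way to confirm that the policy leaves the math unchanged.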
