medium primitives

Saved Residuals in custom_vjp

Why this matters

In a custom VJP, the backward pass often needs values computed during the forward pass. The correct pattern is to save them as residuals — return them from _fwd alongside the primal output, and unpack them in _bwd.

This avoids recomputation in the backward pass, which would be both wasteful and potentially incorrect (e.g., if the forward op is stochastic or has side effects).

Saving multiple residuals is simply a matter of returning a tuple: (primal_output, (val1, val2, ...)).

Real-world uses:

  • Fused activation functions — save the pre-activation for the backward (e.g., in custom ReLU, sigmoid, or gating operations).
  • Normalization layers — save the batch mean and variance (or the inverse standard deviation) computed in the forward pass.
  • Custom attention — save the attention weights and key/query norms.
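
As a rough illustration of the first use case, the sketch below wraps a sigmoid gate in a custom VJP and saves the gate value computed in the forward pass. The names gated, gated_fwd and gated_bwd are hypothetical and chosen only for this example; it is one way to lay out the rule, not a reference implementation.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def gated(x, w):
    # Elementwise gate: x scaled by sigmoid(w); x and w share a shape.
    return x * jax.nn.sigmoid(w)

def gated_fwd(x, w):
    s = jax.nn.sigmoid(w)          # computed once in the forward pass
    return x * s, (x, s)           # save x and the gate value as residuals

def gated_bwd(residuals, g):
    x, s = residuals
    # d/dx = s, d/dw = x * s * (1 - s); the sigmoid is not evaluated again.
    return g * s, g * x * s * (1.0 - s)

gated.defvjp(gated_fwd, gated_bwd)

x = jnp.array([0.5, -1.0, 2.0])
w = jnp.array([0.0, 1.0, -1.0])
dx, dw = jax.grad(lambda x, w: jnp.sum(gated(x, w)), argnums=(0, 1))(x, w)
```

Saving s rather than w means the backward pass needs only multiplications.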

Worked mini-example

import jax
import jax.numpy as jnp

@jax.custom_vjp
def scaled(x, scale):
    return x * scale

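def fwd(x, scale):
    # Return the primal output plus everything bwd will need.
    return scaled(x, scale), (x, scale)   # save both as residuals

def bwd(residuals, g):
    x, scale = residuals
    # One cotangent per primal input. The sum assumes scale is a scalar
    # that was broadcast across x.
    return (g * scale, jnp.sum(g * x))    # grad w.r.t. x and scale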

scaled.defvjp(fwd, bwd)
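
To see the rule in action, one option is to differentiate a scalar loss through scaled with respect to both arguments. The snippet restates the definition so it runs on its own; the input values and the sum loss are assumptions made for the demonstration, and scale is taken to be a scalar.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def scaled(x, scale):
    return x * scale

def fwd(x, scale):
    return scaled(x, scale), (x, scale)   # residuals: both inputs

def bwd(residuals, g):
    x, scale = residuals
    return (g * scale, jnp.sum(g * x))    # assumes a scalar scale

scaled.defvjp(fwd, bwd)

x = jnp.array([1.0, 2.0, 3.0])
loss = lambda x, scale: jnp.sum(scaled(x, scale))

# The cotangent g of a sum is all ones, so:
dx, dscale = jax.grad(loss, argnums=(0, 1))(x, 2.0)
# dx     -> [2. 2. 2.]   (g * scale)
# dscale -> 6.0          (sum of g * x)
```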

Common pitfalls

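  • Residuals travel as a single object — _fwd must return exactly two things: the primal output and one residual value. To save several values, wrap them in a tuple such as (val1, val2), or in any other pytree (list, dict). Returning them as extra top-level outputs will error.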
  • Memory cost — saved residuals persist in memory until the backward pass completes. For large intermediates, this matters.
  • Recomputing in _bwd defeats the purpose — if you recompute sin(x) in the backward, you’re not saving time and may get subtly different floating-point results.
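
The first two pitfalls can be sketched together with a small normalization op. This is a hedged example: normalize, its helper names, and the EPS constant are all assumptions introduced here. The residuals travel as one dict, and only the values the backward formula needs are kept.

```python
import jax
import jax.numpy as jnp

EPS = 1e-5  # assumed numerical stabilizer, chosen for this sketch

@jax.custom_vjp
def normalize(x):
    # Zero-mean, unit-variance normalization of a 1-D array.
    return (x - jnp.mean(x)) * jax.lax.rsqrt(jnp.var(x) + EPS)

def normalize_fwd(x):
    inv_std = jax.lax.rsqrt(jnp.var(x) + EPS)
    y = (x - jnp.mean(x)) * inv_std
    # Exactly two outputs: the primal result and ONE residual pytree.
    return y, {"y": y, "inv_std": inv_std}

def normalize_bwd(residuals, g):
    y, inv_std = residuals["y"], residuals["inv_std"]
    # Standard normalization gradient, written in terms of the saved values.
    dx = inv_std * (g - jnp.mean(g) - y * jnp.mean(g * y))
    return (dx,)  # one cotangent per primal input, as a tuple

normalize.defvjp(normalize_fwd, normalize_bwd)

x = jnp.array([1.0, 2.0, 4.0, 7.0])
w = jnp.array([0.3, -1.2, 0.5, 2.0])
dx = jax.grad(lambda x: jnp.sum(w * normalize(x)))(x)
```

Keeping y and inv_std instead of x, the mean, and the variance is a design choice: the backward formula needs nothing else, so the residual footprint stays at one array plus one scalar.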

Problem

Implement custom_vjp_with_extra_residual(x) that:

  1. Defines f_with_sin(x) = sum(x * sin(x)) with @jax.custom_vjp.
  2. In _fwd, computes and saves sin_x = sin(x) alongside x as residuals.
  3. In _bwd, uses saved x and sin_x to compute cotangent * (sin(x) + x * cos(x)).
  4. Returns jax.grad(f_with_sin)(x).
Arguments:

  • x: 1-D JAX array.

Returns: 1-D array with the same shape as x — the gradient of sum(x · sin(x)) w.r.t. x.
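
One possible way to put the four steps together is sketched below. It is an illustration written from the problem statement, not the site's reference solution; the inner names _fwd and _bwd follow the wording of the steps.

```python
import jax
import jax.numpy as jnp

def custom_vjp_with_extra_residual(x):
    @jax.custom_vjp
    def f_with_sin(x):
        return jnp.sum(x * jnp.sin(x))

    def _fwd(x):
        sin_x = jnp.sin(x)                        # computed once
        return jnp.sum(x * sin_x), (x, sin_x)     # output, residual tuple

    def _bwd(residuals, cotangent):
        x, sin_x = residuals
        # d/dx [x * sin(x)] = sin(x) + x * cos(x); sin(x) comes from the residuals.
        return (cotangent * (sin_x + x * jnp.cos(x)),)

    f_with_sin.defvjp(_fwd, _bwd)
    return jax.grad(f_with_sin)(x)

grad_x = custom_vjp_with_extra_residual(jnp.array([0.0, 1.0, 2.0]))
```

A quick sanity check is to compare the result against jax.grad applied to the same expression without the custom rule.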

Hints

jax custom-vjp residuals
