Saved Residuals in custom_vjp
Why this matters
In a custom VJP, the backward pass often needs values computed during the
forward pass. The correct pattern is to save them as residuals — return
them from _fwd alongside the primal output, and unpack them in _bwd.
This avoids recomputation in the backward pass, which would be both wasteful and potentially incorrect (e.g., if the forward op is stochastic or has side effects).
Saving multiple residuals is simply a matter of returning a tuple:
(primal_output, (val1, val2, ...)).
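Here is a minimal sketch of the tuple-of-residuals pattern. The function f(x, y) = sin(x) * y is a hypothetical example, not one taken from this page. The forward rule saves three values in one tuple, and the backward rule unpacks them. The last lines compare the custom gradient with JAX's own autodiff of the same expression.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    cos_x, sin_x = jnp.cos(x), jnp.sin(x)
    # One residual object: a tuple holding three saved values.
    return sin_x * y, (cos_x, sin_x, y)

def f_bwd(res, g):
    cos_x, sin_x, y = res  # unpack the values saved in f_fwd
    # d/dx [sin(x) * y] = cos(x) * y,  d/dy [sin(x) * y] = sin(x)
    return (g * cos_x * y, g * sin_x)

f.defvjp(f_fwd, f_bwd)

x, y = jnp.array(0.5), jnp.array(2.0)
custom = jax.grad(f, argnums=(0, 1))(x, y)
reference = jax.grad(lambda a, b: jnp.sin(a) * b, argnums=(0, 1))(x, y)
print(custom, reference)
```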
Real-world uses:
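- Fused activation functions: save the pre-activation for the backward pass (e.g., in custom ReLU, sigmoid, or gating operations).
- Normalization layers: save the batch mean and variance (or inverse standard deviation) computed in the forward pass. The running statistics are for inference and are not what the backward pass needs.
- Custom attention: save the attention weights and key/query norms.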
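To make the normalization case concrete, here is a minimal sketch of a normalization over the last axis, in the style of layer normalization, with no learned scale or shift. It saves the normalized values and the inverse standard deviation as residuals. Several choices here are assumptions for illustration: the names normalize and _normalize_plain, the EPS value, and saving x_hat and rstd rather than the mean and variance. The backward rule uses the standard closed form for this normalization. The check uses a weighted sum as the loss, because a plain sum of normalized values has a zero gradient.

```python
import jax
import jax.numpy as jnp

EPS = 1e-5  # hypothetical epsilon for numerical stability

def _normalize_plain(x):
    # Reference implementation that JAX differentiates automatically.
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + EPS)

@jax.custom_vjp
def normalize(x):
    return _normalize_plain(x)

def normalize_fwd(x):
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    rstd = 1.0 / jnp.sqrt(var + EPS)  # inverse standard deviation
    x_hat = (x - mu) * rstd
    # Save x_hat and rstd so the backward pass does not recompute statistics.
    return x_hat, (x_hat, rstd)

def normalize_bwd(res, g):
    x_hat, rstd = res
    # Per row: dL/dx = rstd * (g - mean(g) - x_hat * mean(g * x_hat))
    mean_g = jnp.mean(g, axis=-1, keepdims=True)
    mean_gx = jnp.mean(g * x_hat, axis=-1, keepdims=True)
    return (rstd * (g - mean_g - x_hat * mean_gx),)  # one entry per input

normalize.defvjp(normalize_fwd, normalize_bwd)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
w = jax.random.normal(jax.random.PRNGKey(1), (4, 8))  # arbitrary loss weights

def weighted_sum(fn):
    return lambda v: jnp.sum(fn(v) * w)

custom_grad = jax.grad(weighted_sum(normalize))(x)
reference_grad = jax.grad(weighted_sum(_normalize_plain))(x)
print(jnp.max(jnp.abs(custom_grad - reference_grad)))
```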
Worked mini-example
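```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def scaled(x, scale):
    return x * scale

def fwd(x, scale):
    # Save both inputs as a single tuple residual.
    return scaled(x, scale), (x, scale)

def bwd(residuals, g):
    x, scale = residuals
    # Gradients w.r.t. x and scale; the scale gradient is summed
    # because scale is assumed to be a scalar broadcast over x.
    return (g * scale, jnp.sum(g * x))

scaled.defvjp(fwd, bwd)
```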
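Here is a quick usage check for the mini-example, restated so the snippet runs on its own. Differentiating sum(scaled(x, scale)) with respect to both arguments should give scale for every element of x, and sum(x) for scale. The inputs are arbitrary. Passing scale as a 0-d array is a choice made here so that the cotangent returned by bwd matches the primal's shape and dtype.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def scaled(x, scale):
    return x * scale

def fwd(x, scale):
    return scaled(x, scale), (x, scale)

def bwd(residuals, g):
    x, scale = residuals
    return (g * scale, jnp.sum(g * x))  # assumes a scalar scale

scaled.defvjp(fwd, bwd)

x = jnp.array([1.0, 2.0, 3.0])
scale = jnp.array(2.0)
dx, dscale = jax.grad(lambda v, s: jnp.sum(scaled(v, s)), argnums=(0, 1))(x, scale)
print(dx)      # [2. 2. 2.]
print(dscale)  # 6.0
```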
Common pitfalls
- The residual is always a single object: if you need multiple values, wrap them in a tuple, (val1, val2). Returning them as separate values from _fwd, such as (output, val1, val2), will raise an error.
- Memory cost: saved residuals persist in memory until the backward pass completes. For large intermediates, this matters.
- Recomputing in _bwd defeats the purpose: if you recompute sin(x) in the backward pass, you are not saving time, and you may get subtly different floating-point results.
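The residual only has to be a single pytree of arrays, so a dict works as well as a tuple and can make the backward rule easier to read. Below is a minimal sketch using a hypothetical function x_exp(x) = x * exp(x). It saves exp(x) so the backward rule reuses it instead of recomputing it, and the result is compared against plain autodiff.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def x_exp(x):
    return x * jnp.exp(x)

def x_exp_fwd(x):
    exp_x = jnp.exp(x)
    # A single residual object: here a dict (any pytree of arrays works).
    return x * exp_x, {"x": x, "exp_x": exp_x}

def x_exp_bwd(res, g):
    # Reuse the saved exp(x): d/dx [x * exp(x)] = exp(x) * (1 + x)
    return (g * res["exp_x"] * (1.0 + res["x"]),)

x_exp.defvjp(x_exp_fwd, x_exp_bwd)

x = jnp.linspace(-1.0, 1.0, 5)
grad_custom = jax.grad(lambda v: jnp.sum(x_exp(v)))(x)
grad_auto = jax.grad(lambda v: jnp.sum(v * jnp.exp(v)))(x)
print(grad_custom, grad_auto)
```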
Problem
Implement custom_vjp_with_extra_residual(x) that:
- Defines f_with_sin(x) = sum(x * sin(x)) with @jax.custom_vjp.
- In _fwd, computes and saves sin_x = sin(x) alongside x as residuals.
- In _bwd, uses the saved x and sin_x to compute cotangent * (sin(x) + x * cos(x)).
- Returns jax.grad(f_with_sin)(x).
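The backward formula in the requirements follows from differentiating the sum element by element. Because f_with_sin returns a scalar, its cotangent \(\bar{g}\) is a scalar that scales every entry, and under jax.grad it equals 1:

```latex
f(x) = \sum_i x_i \sin x_i
\quad\Longrightarrow\quad
\frac{\partial f}{\partial x_i} = \sin x_i + x_i \cos x_i,
\qquad
\bar{x}_i = \bar{g}\,\bigl(\sin x_i + x_i \cos x_i\bigr).
```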
Input: x, a 1-D JAX array.
Returns: a 1-D array of the same shape, the gradient of sum(x·sin(x)) with respect to x.
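Here is one possible sketch that follows the requirements. It is not the site's reference solution. Defining f_with_sin inside the wrapper is a choice made here to keep the snippet self-contained, and the names f_with_sin_fwd and f_with_sin_bwd are illustrative. Only x and sin(x) are saved, so cos(x) is still computed in the backward rule.

```python
import jax
import jax.numpy as jnp

def custom_vjp_with_extra_residual(x):
    @jax.custom_vjp
    def f_with_sin(x):
        return jnp.sum(x * jnp.sin(x))

    def f_with_sin_fwd(x):
        sin_x = jnp.sin(x)
        # Save x and sin(x) together as one tuple residual.
        return jnp.sum(x * sin_x), (x, sin_x)

    def f_with_sin_bwd(res, g):
        x, sin_x = res
        # g is the scalar cotangent of the sum; d/dx_i = sin(x_i) + x_i*cos(x_i)
        return (g * (sin_x + x * jnp.cos(x)),)

    f_with_sin.defvjp(f_with_sin_fwd, f_with_sin_bwd)
    return jax.grad(f_with_sin)(x)

x = jnp.array([0.0, 1.0, 2.0])
grad = custom_vjp_with_extra_residual(x)
print(grad)
```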