Custom JVP: Clip with Pass-through Gradient
Why this matters
@jax.custom_jvp defines BOTH the primal output and the tangent (directional
derivative) in a single rule, making it forward-mode-aware. This is
the right tool for the Straight-Through Estimator (STE): let the forward
pass be a discontinuous or non-differentiable op, but let the gradient
flow as if it were the identity.
Standard jnp.clip has zero gradient outside [lo, hi], which is correct
but kills training signals for inputs that are already clipped. The STE trick
ignores the zero and passes the upstream gradient through unchanged, keeping
the optimization landscape connected.
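To make the zero-gradient behavior concrete, the following minimal sketch differentiates a sum over plain jnp.clip. The input values and the bounds 0.0 and 1.0 are illustrative choices only.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, 0.5, 3.0])

# Gradient of sum(clip(x)) with respect to x, using the built-in rule.
g_clip = jax.grad(lambda x: jnp.clip(x, 0.0, 1.0).sum())(x)
print(g_clip)  # zero for the clipped entries (-2.0 and 3.0), one for 0.5
```

The entries at -2.0 and 3.0 receive no gradient, so an optimizer has no signal to move them back into range.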
STE variants appear in:
- Quantization-aware training (rounding with identity backward).
- Binary/ternary networks (sign function with identity backward).
- Hard attention (argmax forward, softmax backward).
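As one illustration of the quantization-aware-training variant listed above, the sketch below applies the same custom_jvp pattern to rounding. The name round_ste is hypothetical and chosen only for this example; it is not a JAX API.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def round_ste(x):  # hypothetical name, for illustration only
    return jnp.round(x)

@round_ste.defjvp
def _round_ste_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    # Forward value is rounded; the tangent passes through as identity.
    return jnp.round(x), x_dot

x = jnp.array([0.2, 1.7, -0.6])
y = round_ste(x)                               # rounded values
g = jax.grad(lambda x: round_ste(x).sum())(x)  # all ones under the STE
```

Without the custom rule, the gradient of jnp.round would be zero everywhere, so nothing upstream of the rounding step could be trained.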
Worked mini-example
import jax
import jax.numpy as jnp

@jax.custom_jvp
def clip_passthrough(x, lo, hi):
    return jnp.clip(x, lo, hi)

@clip_passthrough.defjvp
def _jvp(primals, tangents):
    x, lo, hi = primals
    x_dot, _, _ = tangents
    return jnp.clip(x, lo, hi), x_dot  # identity tangent

x = jnp.array([-2.0, 0.5, 3.0])
print(clip_passthrough(x, 0.0, 1.0))  # [0., 0.5, 1.]

# Gradient via jax.grad uses the JVP rule under the hood
g = jax.grad(lambda x: clip_passthrough(x, 0.0, 1.0).sum())(x)
# Standard clip: [0., 1., 0.]  (zero outside [lo, hi])
# Pass-through:  [1., 1., 1.]  (always 1, the STE)
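Because the rule is defined in forward mode, it can also be exercised directly with jax.jvp. The sketch below repeats the definition so that it runs on its own; the tangent direction v is an arbitrary choice made for illustration.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def clip_passthrough(x, lo, hi):
    return jnp.clip(x, lo, hi)

@clip_passthrough.defjvp
def _jvp(primals, tangents):
    x, lo, hi = primals
    x_dot, _, _ = tangents
    return jnp.clip(x, lo, hi), x_dot  # identity tangent

x = jnp.array([-2.0, 0.5, 3.0])
v = jnp.array([1.0, 10.0, 100.0])  # arbitrary tangent direction

# Close over the bounds so that only x is differentiated.
primal_out, tangent_out = jax.jvp(
    lambda x: clip_passthrough(x, 0.0, 1.0), (x,), (v,)
)
# primal_out is the clipped input; tangent_out equals v unchanged.
```

The tangent comes back untouched even for the clipped entries, which is the forward-mode view of the pass-through gradient.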
Common pitfalls
- @jax.custom_jvp vs @jax.custom_vjp: use custom_jvp when you want to set the forward-mode tangent rule directly. It also implies the reverse-mode rule via transposition.
- Registration order: first @jax.custom_jvp on the function, then @<fn>.defjvp on the JVP rule. Using defvjp by mistake raises an error.
- defjvp receives TUPLES: primals and tangents are both tuples, one entry per input. Unpack all three: x, lo, hi = primals.
- Must return BOTH: (primal_out, tangent_out). Returning only one raises an error, because JAX expects a pair.
- lo and hi tangents are typically ignored: the bounds are scalar constants, so their tangents are discarded with _.
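For contrast with the first pitfall, here is a rough sketch of what a reverse-mode-only version might look like with jax.custom_vjp. The names clip_passthrough_vjp, _fwd, and _bwd are hypothetical. Returning None for the lo and hi cotangents follows the pattern shown in the JAX documentation on custom derivative rules, where None indicates a zero cotangent.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_passthrough_vjp(x, lo, hi):  # hypothetical name
    return jnp.clip(x, lo, hi)

def _fwd(x, lo, hi):
    # Forward pass: primal output plus residuals (none are needed here).
    return jnp.clip(x, lo, hi), None

def _bwd(residuals, g):
    # One cotangent per input; None marks a zero cotangent for lo and hi.
    return g, None, None

clip_passthrough_vjp.defvjp(_fwd, _bwd)

x = jnp.array([-2.0, 0.5, 3.0])
g_vjp = jax.grad(lambda x: clip_passthrough_vjp(x, 0.0, 1.0).sum())(x)
```

This variant gives the same pass-through gradient under jax.grad, but a function defined with custom_vjp does not support forward-mode differentiation, which is one reason to prefer custom_jvp here.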
Problem
Implement clip_passthrough(x, lo, hi) that returns jnp.clip(x, lo, hi)
in the forward pass and passes the upstream gradient through unchanged in
the backward pass (Straight-Through Estimator).
- x: 1-D jax array.
- lo: scalar lower bound.
- hi: scalar upper bound.
Returns: 1-D array with the same shape as x, equal to jnp.clip(x, lo, hi).