
Custom JVP: Clip with Pass-through Gradient

Why this matters

@jax.custom_jvp lets you attach your own JVP rule to a function. That rule returns BOTH the primal output and the tangent (directional derivative) in one place, so the function works in forward mode, and JAX derives reverse mode from the same rule by transposition. This is the right tool for the Straight-Through Estimator (STE): the forward pass can be a discontinuous or non-differentiable op, while the gradient flows as if the op were the identity.
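To make "one rule, both outputs" concrete, here is a minimal sketch using rounding (as in quantization-aware training) instead of clip. It assumes only the public jax.custom_jvp, jax.jvp and jax.grad APIs; the name round_ste is illustrative. Forward mode calls the rule directly, and jax.grad obtains the reverse-mode gradient by transposing it.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def round_ste(x):
    # Primal: rounding, whose true derivative is zero almost everywhere.
    return jnp.round(x)

@round_ste.defjvp
def _round_ste_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # One rule returns both outputs: the rounded primal and an identity tangent.
    return jnp.round(x), x_dot

x = jnp.array([0.2, 1.7, -0.6])

# Forward mode: jax.jvp evaluates the custom rule directly.
y, y_dot = jax.jvp(round_ste, (x,), (jnp.ones_like(x),))   # y == [0., 2., -1.], y_dot == [1., 1., 1.]

# Reverse mode: jax.grad transposes the (linear) tangent rule.
g_ste = jax.grad(lambda x: round_ste(x).sum())(x)          # [1., 1., 1.]

# For comparison, plain jnp.round has a zero derivative everywhere.
g_plain = jax.grad(lambda x: jnp.round(x).sum())(x)        # [0., 0., 0.]
```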

Standard jnp.clip has zero gradient outside [lo, hi]. That is mathematically correct, but it cuts off the training signal for every input that is already clipped. The STE trick ignores that zero and passes the upstream gradient through unchanged, so clipped inputs still receive updates.
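A quick way to see the dead gradient is to differentiate a loss through plain jnp.clip. The squared-error loss and the all-zero target below are illustrative assumptions, not part of the original problem.

```python
import jax
import jax.numpy as jnp

lo, hi = 0.0, 1.0
x = jnp.array([-2.0, 0.5, 3.0])
target = jnp.zeros_like(x)

def loss(x):
    # Squared error measured after clipping to [lo, hi].
    return jnp.sum((jnp.clip(x, lo, hi) - target) ** 2)

g = jax.grad(loss)(x)
# g == [0., 1., 0.]: x[2] = 3.0 clips to 1.0, which is far from its target 0.0,
# yet its gradient is exactly zero, so gradient descent never moves it.
```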

STE variants appear in:

  • Quantization-aware training (rounding with identity backward).
  • Binary/ternary networks (sign function with identity backward).
  • Hard attention (argmax forward, softmax backward).
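The sign and hard-attention variants follow the same pattern. Below is a minimal sketch; the names sign_ste and hard_attention are hypothetical. For hard attention it assumes "softmax backward" means borrowing the softmax's directional derivative via jax.jvp inside the rule. Other straight-through variants (for example Gumbel-softmax) exist and are not shown.

```python
import jax
import jax.numpy as jnp

# Binary networks: sign forward, identity backward.
@jax.custom_jvp
def sign_ste(x):
    return jnp.sign(x)

@sign_ste.defjvp
def _sign_ste_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return jnp.sign(x), x_dot

def _one_hot_argmax(logits):
    # One-hot vector marking the largest logit along the last axis.
    return jax.nn.one_hot(jnp.argmax(logits, axis=-1), logits.shape[-1],
                          dtype=logits.dtype)

# Hard attention: one-hot argmax forward, softmax derivative backward.
@jax.custom_jvp
def hard_attention(logits):
    return _one_hot_argmax(logits)

@hard_attention.defjvp
def _hard_attention_jvp(primals, tangents):
    (logits,), (logits_dot,) = primals, tangents
    # The softmax tangent is linear in logits_dot, so jax.grad can transpose it.
    _, soft_dot = jax.jvp(jax.nn.softmax, (logits,), (logits_dot,))
    return _one_hot_argmax(logits), soft_dot

x = jnp.array([-0.3, 0.0, 2.0])
g_sign = jax.grad(lambda x: sign_ste(x).sum())(x)            # [1., 1., 1.]

logits = jnp.array([1.0, 3.0, 2.0])
values = jnp.array([10.0, 20.0, 30.0])
attended = hard_attention(logits) @ values                   # 20.0: picks values[1]
g_hard = jax.grad(lambda l: hard_attention(l) @ values)(logits)
g_soft = jax.grad(lambda l: jax.nn.softmax(l) @ values)(logits)  # equals g_hard
```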

Worked mini-example

import jax
import jax.numpy as jnp

@jax.custom_jvp
def clip_passthrough(x, lo, hi):
    return jnp.clip(x, lo, hi)

@clip_passthrough.defjvp
def _jvp(primals, tangents):
    x, lo, hi = primals
    x_dot, _, _ = tangents
    return jnp.clip(x, lo, hi), x_dot   # identity tangent

x = jnp.array([-2.0, 0.5, 3.0])
print(clip_passthrough(x, 0.0, 1.0))   # [0., 0.5, 1.]

# jax.grad runs the custom JVP rule and transposes it for reverse mode
g = jax.grad(lambda x: clip_passthrough(x, 0.0, 1.0).sum())(x)
print(g)   # [1. 1. 1.]
# Standard jnp.clip would give [0., 1., 0.] (zero outside [lo, hi]);
# the pass-through rule gives 1 everywhere (the STE).
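Because the rule is a JVP, the worked example also supports forward mode (jax.jvp) and composes with jax.jit. The sketch below repeats the definition so it runs on its own, and adds, for comparison, the common stop_gradient formulation of the STE. clip_passthrough_sg is a hypothetical name, and that formulation is a different technique, not a custom_jvp.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def clip_passthrough(x, lo, hi):
    return jnp.clip(x, lo, hi)

@clip_passthrough.defjvp
def _jvp(primals, tangents):
    x, lo, hi = primals
    x_dot, _, _ = tangents
    return jnp.clip(x, lo, hi), x_dot

def clip_passthrough_sg(x, lo, hi):
    # Alternative STE without custom_jvp: the value is clip(x, lo, hi),
    # but only the leading `x` term carries a derivative.
    return x + jax.lax.stop_gradient(jnp.clip(x, lo, hi) - x)

x = jnp.array([-2.0, 0.5, 3.0])
f = jax.jit(lambda x: clip_passthrough(x, 0.0, 1.0))

# Forward mode works because the custom rule is a JVP.
y, y_dot = jax.jvp(f, (x,), (jnp.ones_like(x),))   # y == [0., 0.5, 1.], y_dot == [1., 1., 1.]

# Both formulations give the same STE gradient on these inputs.
g_custom = jax.grad(lambda x: f(x).sum())(x)
g_sg = jax.grad(lambda x: clip_passthrough_sg(x, 0.0, 1.0).sum())(x)
```

One reason to prefer custom_jvp: it returns exactly jnp.clip(x, lo, hi), whereas the stop_gradient version computes x + (clip - x) in floating point and can lose the clipped value when |x| is much larger than the bounds (in float32, 1e8 + (1.0 - 1e8) evaluates to 0.0, not 1.0).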

Common pitfalls

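  • @jax.custom_jvp vs @jax.custom_vjp: use custom_jvp when you want to set the forward-mode tangent rule directly. As long as the rule is linear in the tangents, JAX derives the reverse-mode rule from it by transposition, so jax.grad works too; a custom_vjp function, by contrast, cannot be used with forward-mode jax.jvp.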
  • Registration order: decorate the function with @jax.custom_jvp first, then register the rule with @<fn>.defjvp. A custom_jvp function has no defvjp method, so calling defvjp by mistake raises an AttributeError.
  • defjvp receives TUPLES: primals and tangents are both tuples, one entry per input. Unpack all three: x, lo, hi = primals.
  • Must return BOTH: the rule returns the pair (primal_out, tangent_out). Returning only one of them raises a TypeError, because JAX expects a pair of primal and tangent outputs.
  • lo and hi tangents are typically ignored: they are scalar constants here, so their tangents are discarded with _ (and the rule contributes zero gradient to lo and hi).
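Two variations on the pitfalls above, sketched under the assumption that lo and hi are plain Python floats: passing them through nondiff_argnums so the rule never sees their tangents, and the equivalent reverse-mode-only custom_vjp. The names clip_passthrough_nd and clip_passthrough_vjp are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Variant A: mark lo and hi (argument positions 1 and 2) as non-differentiable.
@partial(jax.custom_jvp, nondiff_argnums=(1, 2))
def clip_passthrough_nd(x, lo, hi):
    return jnp.clip(x, lo, hi)

@clip_passthrough_nd.defjvp
def _clip_nd_jvp(lo, hi, primals, tangents):
    # Non-differentiable args come first; primals/tangents now hold only x.
    (x,), (x_dot,) = primals, tangents
    return jnp.clip(x, lo, hi), x_dot

# Variant B: the same STE as a custom_vjp (reverse mode only).
@jax.custom_vjp
def clip_passthrough_vjp(x, lo, hi):
    return jnp.clip(x, lo, hi)

def _clip_fwd(x, lo, hi):
    # No residuals are needed for an identity backward pass.
    return jnp.clip(x, lo, hi), None

def _clip_bwd(res, g):
    # One cotangent per primal input; None means a zero cotangent for lo and hi.
    return g, None, None

clip_passthrough_vjp.defvjp(_clip_fwd, _clip_bwd)

x = jnp.array([-2.0, 0.5, 3.0])
g_nd = jax.grad(lambda x: clip_passthrough_nd(x, 0.0, 1.0).sum())(x)    # [1., 1., 1.]
g_vjp = jax.grad(lambda x: clip_passthrough_vjp(x, 0.0, 1.0).sum())(x)  # [1., 1., 1.]
```

Variant B cannot be passed to jax.jvp, which is exactly the custom_jvp vs custom_vjp trade-off. The JAX documentation presents nondiff_argnums mainly for non-array arguments such as Python callables, so variant A is best kept for static Python scalars rather than traced arrays.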

Problem

Implement clip_passthrough(x, lo, hi) that returns jnp.clip(x, lo, hi) in the forward pass and passes the upstream gradient through unchanged in the backward pass (Straight-Through Estimator).

  • x: 1-D jax array.
  • lo: scalar lower bound.
  • hi: scalar upper bound.

Returns: 1-D array with the same shape as x, equal to jnp.clip(x, lo, hi).
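One way to check an implementation against this spec is a small harness. check_clip_passthrough is a hypothetical helper (not the site's grader), and the candidate repeats the worked example's custom_jvp definition so the block is self-contained. The harness checks the forward values, checks that an arbitrary upstream gradient w reaches x unchanged, and checks that the result survives jax.jit.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def clip_passthrough(x, lo, hi):
    return jnp.clip(x, lo, hi)

@clip_passthrough.defjvp
def _jvp(primals, tangents):
    x, lo, hi = primals
    x_dot, _, _ = tangents
    return jnp.clip(x, lo, hi), x_dot

def check_clip_passthrough(fn, lo=0.0, hi=1.0):
    """Hypothetical checker for the spec; returns True if every check passes."""
    x = jnp.array([-2.0, 0.5, 3.0])
    w = jnp.array([3.0, -1.0, 2.0])   # arbitrary upstream gradient
    # Forward pass: same shape and values as jnp.clip.
    y = fn(x, lo, hi)
    ok = y.shape == x.shape and bool(jnp.allclose(y, jnp.clip(x, lo, hi)))
    # Backward pass: w reaches x unchanged, even where x was clipped.
    g = jax.grad(lambda x: jnp.vdot(fn(x, lo, hi), w))(x)
    ok = ok and bool(jnp.allclose(g, w))
    # Same gradient under jit, with lo and hi traced as arguments.
    g_jit = jax.jit(jax.grad(lambda x, lo, hi: jnp.vdot(fn(x, lo, hi), w)))(x, lo, hi)
    ok = ok and bool(jnp.allclose(g_jit, w))
    return ok
```

Plain jnp.clip fails this harness, because its gradient zeroes the entries that were clipped.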

