Custom VJP: Stable log1pexp
Why this matters
@jax.custom_vjp lets you override the backward pass without changing
the forward computation. This makes it a go-to tool for numerical stability:
you supply a mathematically equivalent gradient that avoids overflow for
extreme inputs.
The classic example is log(1 + exp(x)) (the softplus activation). Its
automatic gradient is exp(x) / (1 + exp(x)), which breaks down for large
x: once exp(x) exceeds float32 range (x above roughly 88), the numerator
and denominator both overflow and the result is nan instead of a value near 1.
The mathematically equivalent form sigmoid(x) = 1 / (1 + exp(-x)) is
numerically stable for all x, and that is what the custom VJP provides.
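The equivalence of the two forms follows from multiplying the numerator and denominator by exp(-x). A short derivation, written with σ for the sigmoid:

```latex
\frac{d}{dx}\log\bigl(1+e^{x}\bigr)
  = \frac{e^{x}}{1+e^{x}}
  = \frac{e^{x}\,e^{-x}}{(1+e^{x})\,e^{-x}}
  = \frac{1}{e^{-x}+1}
  = \sigma(x)
```

In the last form, large positive x drives exp(-x) toward 0 rather than toward inf, so the quotient stays finite.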
This pattern appears everywhere:
- Stable log-sum-exp for attention and CTC.
- Clamped gradients in reinforcement learning.
- Implicit function theorem gradients (see the next problem).
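As an illustration of the clamped-gradient item in the list above, here is a minimal sketch of an identity function whose incoming gradient is clipped on the way back. It is not taken from any particular reinforcement learning library; make_clip_gradient, lo, and hi are hypothetical names chosen for this example.

```python
import jax
import jax.numpy as jnp


def make_clip_gradient(lo, hi):
    """Build an identity function whose backward pass clips the cotangent.

    `lo` and `hi` are plain Python floats captured by closure (an assumption
    made here to keep the example small).
    """

    @jax.custom_vjp
    def clip_gradient(x):
        return x  # forward pass: identity

    def fwd(x):
        return x, None  # nothing needs to be saved for the backward pass

    def bwd(_, g):
        # One entry per input of clip_gradient, hence the 1-tuple.
        return (jnp.clip(g, lo, hi),)

    clip_gradient.defvjp(fwd, bwd)
    return clip_gradient


clip = make_clip_gradient(-1.0, 1.0)
x = jnp.array([1.0, 2.0])

# Without clipping, the gradient of (10 * x).sum() would be [10, 10].
clipped = jax.grad(lambda x: (10.0 * clip(x)).sum())(x)
print(clipped)  # [1. 1.]
```

The forward value is untouched; only the gradient that flows back through clip is limited to [lo, hi].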
Worked mini-example
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.log1p(jnp.exp(x))

def fwd(x):
    return f(x), x  # return (primal, residuals)

def bwd(x, g):
    return (g * jax.nn.sigmoid(x),)  # tuple, one entry per input

f.defvjp(fwd, bwd)

x = jnp.array([0.0, 1.0])
print(f(x))  # [0.693, 1.313]
print(jax.grad(lambda x: f(x).sum())(x))  # sigmoid([0, 1]) = [0.5, 0.731]
Without the custom VJP, jax.grad would differentiate through
log1p(exp(x)) automatically, but for x = 100 the intermediate
exp(100) overflows to inf in float32, and the backward pass then evaluates
inf * (1 / inf) = inf * 0 = nan.
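The failure described above can be reproduced in a few lines. The sketch below compares a plain version (named naive here, purely for illustration) against the custom-VJP version from the mini-example, assuming JAX's default float32 precision.

```python
import jax
import jax.numpy as jnp


def naive(x):
    # Same forward formula, differentiated automatically by JAX.
    return jnp.log1p(jnp.exp(x))


@jax.custom_vjp
def f(x):
    return jnp.log1p(jnp.exp(x))


def fwd(x):
    return f(x), x  # (primal, residuals)


def bwd(x, g):
    return (g * jax.nn.sigmoid(x),)  # stable gradient


f.defvjp(fwd, bwd)

naive_grad = jax.grad(naive)(100.0)
custom_grad = jax.grad(f)(100.0)
forward_value = f(100.0)

print(naive_grad)     # nan
print(custom_grad)    # 1.0
print(forward_value)  # inf
```

Note that the custom VJP only repairs the gradient: the forward value at x = 100 is still inf, because the forward formula itself is unchanged.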
Common pitfalls
- Wrong decorator: @jax.custom_jvp is different from @jax.custom_vjp. The vjp variant overrides the reverse-mode rule.
- _fwd signature: must return (primal_output, residuals) as a 2-tuple. The residuals are forwarded verbatim to _bwd.
- _bwd returns a TUPLE: one element per input to the original function. The trailing comma is required: (grad,) not (grad).
- Both _fwd and _bwd must be JAX-traceable: no Python side-effects, no conditionals on traced values.
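The tuple rules above are easier to see with more than one input. The following is a minimal sketch using a hypothetical two-argument function, weighted_log1pexp(x, w) = w * log(1 + exp(x)), invented for this illustration: _fwd packs several residuals into one tuple, and _bwd returns one gradient per input.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def weighted_log1pexp(x, w):
    # Hypothetical example function: elementwise w * log(1 + exp(x)).
    return w * jnp.log1p(jnp.exp(x))


def _fwd(x, w):
    sp = jnp.log1p(jnp.exp(x))
    # 2-tuple: (primal_output, residuals). Residuals may be any pytree.
    return w * sp, (x, w, sp)


def _bwd(residuals, g):
    x, w, sp = residuals
    grad_x = g * w * jax.nn.sigmoid(x)  # d/dx of w * log1pexp(x)
    grad_w = g * sp                     # d/dw of w * log1pexp(x)
    return (grad_x, grad_w)             # one entry per input: (x, w)


weighted_log1pexp.defvjp(_fwd, _bwd)

x = jnp.array([0.0, 1.0])
w = jnp.array([2.0, 3.0])
grad_x, grad_w = jax.grad(
    lambda x, w: weighted_log1pexp(x, w).sum(), argnums=(0, 1)
)(x, w)
print(grad_x)  # equals w * sigmoid(x)
print(grad_w)  # equals log(1 + exp(x))
```

Returning fewer or more entries from _bwd than the function has inputs is reported by JAX as an error when the gradient is taken.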
Problem
Implement stable_log1pexp(x) that computes log(1 + exp(x)) using
@jax.custom_vjp, with the backward pass using sigmoid(x) instead of
the automatic derivative.
- x: 1-D JAX array.
Returns: a 1-D array of the same shape, log(1 + exp(x)).
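One possible shape for an answer is sketched below. It is not the reference solution for this problem. It follows the worked mini-example, with one added assumption: the forward pass uses jnp.logaddexp(0.0, x), which equals log(1 + exp(x)) mathematically, so that the forward value also stays finite for large x.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def stable_log1pexp(x):
    # Assumption: logaddexp(0, x) = log(exp(0) + exp(x)) = log(1 + exp(x)),
    # chosen here so the forward value does not overflow for large x.
    return jnp.logaddexp(0.0, x)


def _fwd(x):
    return stable_log1pexp(x), x  # (primal_output, residuals)


def _bwd(x, g):
    return (g * jax.nn.sigmoid(x),)  # 1-tuple: one entry per input


stable_log1pexp.defvjp(_fwd, _bwd)

x = jnp.array([-100.0, 0.0, 100.0])
values = stable_log1pexp(x)
grads = jax.grad(lambda x: stable_log1pexp(x).sum())(x)
print(values)  # approximately [0, 0.693, 100]
print(grads)   # approximately [0, 0.5, 1]
```

Checking the extremes (here x = -100 and x = 100) alongside x = 0 is a quick way to confirm that neither the value nor the gradient turns into inf or nan.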