
Custom VJP: Stable log1pexp

Why this matters

@jax.custom_vjp lets you override the backward (reverse-mode) pass of a function without changing its forward computation. This is the go-to tool for numerical stability: you supply a mathematically equivalent gradient formula that avoids overflow for extreme inputs.

The classic example is log(1 + exp(x)), the softplus activation. Its automatic gradient is exp(x) / (1 + exp(x)), which breaks down for large x: in float32, exp(x) overflows to inf once x exceeds about 88.7, so the numerator and denominator both become inf and the ratio evaluates to nan. The mathematically equivalent form sigmoid(x) = 1 / (1 + exp(-x)) stays finite for all x, and that is what the custom VJP provides.
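To make the overflow concrete, here is a small sketch that evaluates the raw derivative expression and jax.nn.sigmoid side by side in float32 (JAX's default dtype). The input values are arbitrary illustrations, chosen to straddle the overflow threshold of exp.

```python
import jax
import jax.numpy as jnp

# Float32 inputs (JAX's default); exp overflows to inf above roughly 88.7.
x = jnp.array([0.0, 50.0, 100.0])

# Derivative of log(1 + exp(x)) written naively: exp(x) / (1 + exp(x)).
naive = jnp.exp(x) / (1.0 + jnp.exp(x))

# Mathematically equivalent, but stable for large positive x.
stable = jax.nn.sigmoid(x)

print(naive)   # last entry is nan: exp(100) = inf, and inf / inf = nan
print(stable)  # every entry is finite
```

At x = 50 both forms still agree, because exp(50) is about 5e21 and fits in float32; only past the threshold does the naive ratio turn into nan.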

This pattern appears in many places:

  • Stable log-sum-exp for attention and CTC.
  • Clamped gradients in reinforcement learning.
  • Implicit function theorem gradients (see the next problem).
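For the "clamped gradients" item above, a minimal sketch modeled on the gradient-clipping example in JAX's custom-derivatives tutorial: the forward pass is the identity and the backward pass clips the incoming cotangent. The name clip_gradient and the bounds used below are illustrative choices; lo and hi are marked non-differentiable through nondiff_argnums, which is why they arrive first in the backward function.

```python
from functools import partial

import jax
import jax.numpy as jnp


# lo and hi are fixed bounds and are not differentiated (nondiff_argnums).
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
    return x  # forward pass: identity


def clip_gradient_fwd(lo, hi, x):
    return x, None  # no residuals are needed for the backward pass


def clip_gradient_bwd(lo, hi, _, g):
    # Non-differentiable args come first, then the residuals, then the cotangent g.
    # Return one entry per differentiable input (only x here).
    return (jnp.clip(g, lo, hi),)


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

# d/dx [3 * x] would be 3.0, but the cotangent reaching clip_gradient
# is clipped to the range [-1, 1].
clipped = jax.grad(lambda x: 3.0 * clip_gradient(-1.0, 1.0, x))(2.0)
print(clipped)  # 1.0
```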

Worked mini-example

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.log1p(jnp.exp(x))

def fwd(x):
    return f(x), x          # return (primal output, residuals); x is saved for bwd

def bwd(x, g):              # residuals from fwd come first, then the output cotangent g
    return (g * jax.nn.sigmoid(x),)   # tuple, one entry per input

f.defvjp(fwd, bwd)

x = jnp.array([0.0, 1.0])
print(f(x))                   # [0.693, 1.313]
print(jax.grad(lambda x: f(x).sum())(x))  # sigmoid([0, 1]) = [0.5, 0.731]

Without the custom VJP, jax.grad would differentiate through log1p(exp(x)) automatically, but for x = 100 the intermediate exp(100) overflows to inf in float32 and the gradient comes out as nan.
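The nan can be reproduced with a short side-by-side check. This sketch re-declares the worked example under illustrative names (softplus_naive, softplus_custom, _fwd, _bwd are not from the original). It also shows a limitation worth knowing: the custom VJP fixes only the gradient, so the forward value log1p(exp(100)) still overflows to inf in float32. If a finite forward value is needed too, jnp.logaddexp(x, 0.0) computes the same quantity stably, although that goes beyond what this problem asks for.

```python
import jax
import jax.numpy as jnp


def softplus_naive(x):
    return jnp.log1p(jnp.exp(x))  # plain autodiff, no custom rule


@jax.custom_vjp
def softplus_custom(x):
    return jnp.log1p(jnp.exp(x))  # same forward computation


def _fwd(x):
    return softplus_custom(x), x        # residual: the input x


def _bwd(x, g):
    return (g * jax.nn.sigmoid(x),)     # stable derivative


softplus_custom.defvjp(_fwd, _bwd)

x = 100.0
print(jax.grad(softplus_naive)(x))    # nan: exp(100) overflows inside the chain rule
print(jax.grad(softplus_custom)(x))   # 1.0
print(softplus_custom(x))             # inf: the forward pass is unchanged
print(jnp.logaddexp(x, 0.0))          # 100.0: a stable forward alternative
```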

Common pitfalls

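  • Wrong decorator: @jax.custom_jvp defines a forward-mode (JVP) rule, from which JAX derives reverse mode automatically, while @jax.custom_vjp overrides the reverse-mode rule directly. A custom_vjp function cannot be differentiated with forward-mode tools such as jax.jvp.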
  • fwd signature: fwd takes the same arguments as the original function and must return a 2-tuple (primal_output, residuals). The residuals are passed unchanged to bwd as its first argument.
  • bwd returns a TUPLE with one element per input to the original function (excluding any nondiff_argnums). The trailing comma is required: (grad,) not (grad).
  • Both fwd and bwd must be JAX-traceable: no Python side effects and no Python conditionals on traced values (use jnp.where or jax.lax.cond instead).
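On the traceability point: a Python `if x > 0:` on a traced value raises an error under jax.jit, whereas branch-free selection with jnp.where works. As an illustration only (jax.nn.sigmoid already handles this internally), the sketch below uses a hypothetical helper, branchless_sigmoid, that picks the overflow-safe formula per element, calls it from bwd, and runs the gradient under jax.jit. The names softplus, softplus_fwd and softplus_bwd are likewise illustrative.

```python
import jax
import jax.numpy as jnp


def branchless_sigmoid(x):
    # Hypothetical helper: exp(-|x|) lies in (0, 1], so it never overflows.
    e = jnp.exp(-jnp.abs(x))
    # Both branches are computed elementwise; jnp.where keeps the right one.
    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x)).
    return jnp.where(x >= 0, 1.0 / (1.0 + e), e / (1.0 + e))


@jax.custom_vjp
def softplus(x):
    return jnp.log1p(jnp.exp(x))


def softplus_fwd(x):
    return softplus(x), x


def softplus_bwd(x, g):
    # No `if x > 0:` here: x is a tracer, so select with jnp.where instead.
    return (g * branchless_sigmoid(x),)


softplus.defvjp(softplus_fwd, softplus_bwd)

grad_fn = jax.jit(jax.grad(lambda x: softplus(x).sum()))
xs = jnp.array([-100.0, 0.0, 100.0])
print(grad_fn(xs))  # approximately [0., 0.5, 1.]
```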

Problem

Implement stable_log1pexp(x) that computes log(1 + exp(x)) using @jax.custom_vjp, with the backward pass using sigmoid(x) instead of the automatic derivative.

  • x: 1-D jax array.

Returns: a 1-D array of the same shape containing log(1 + exp(x)).

