vjp Basics
Why this matters
jax.vjp is JAX’s reverse-mode autodiff primitive. Calling
jax.vjp(f, x) returns the output f(x) together with a function vjp_fn.
For a function f: R^n → R^m and a cotangent vector g (same shape as
the output), vjp_fn(g) computes the vector-Jacobian product
gᵀ · J_f(x) in a single backward pass.
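As a rough sanity check of that definition, the sketch below compares vjp_fn(g) with an explicitly formed gᵀ · J for a small R^3 → R^2 function. The function f_demo and the particular values of x and g are made up for illustration; jax.jacrev is used only to build the full Jacobian for comparison.

```python
import jax
import jax.numpy as jnp

def f_demo(x):
    # Hypothetical R^3 -> R^2 function, chosen only for illustration.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
g = jnp.array([1.0, -1.0])      # cotangent: same shape as f_demo(x)

y, vjp_fn = jax.vjp(f_demo, x)  # forward pass runs once here
(vjp_out,) = vjp_fn(g)          # backward pass: g^T . J, shape (3,)

J = jax.jacrev(f_demo)(x)       # explicit Jacobian, shape (2, 3)
explicit = g @ J                # same product, formed the expensive way

print(jnp.allclose(vjp_out, explicit))
```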
This is the engine behind jax.grad: for a scalar-output f,
jax.grad(f)(x) is exactly jax.vjp(f, x)[1](1.0)[0]. Knowing the
primitive lets you:
- Apply multiple cotangents without re-running the forward pass.
- Implement custom reverse-mode rules via jax.custom_vjp.
- Build batched Jacobians row-by-row (next problem).
Reverse-mode is cheapest when output dimension ≪ input dimension — the classic loss-function scenario in ML.
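The relationship to jax.grad described above can be checked with a short sketch. The function loss is a hypothetical scalar-output example, not part of the exercise; the cotangent is built with jnp.ones_like so that it matches the output's shape and dtype.

```python
import jax
import jax.numpy as jnp

def loss(x):
    # Hypothetical scalar-output function (many inputs, one output).
    return jnp.sum(jnp.sin(x) * x)

x = jnp.arange(1.0, 5.0)

value, vjp_fn = jax.vjp(loss, x)
via_vjp = vjp_fn(jnp.ones_like(value))[0]  # cotangent 1.0 for a scalar output
via_grad = jax.grad(loss)(x)

print(jnp.allclose(via_vjp, via_grad))

# The same vjp_fn can be reused with another cotangent; no new forward pass.
doubled = vjp_fn(2.0 * jnp.ones_like(value))[0]
```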
Worked mini-example
import jax, jax.numpy as jnp
def f(x):
    return jnp.sum(x ** 2)
x = jnp.array([1.0, 2.0, 3.0])
primal, vjp_fn = jax.vjp(f, x)
# primal → 14.0 (f(x) = 1 + 4 + 9)
grad = vjp_fn(1.0)[0]
# grad → [2., 4., 6.] (2*x with cotangent=1.0)
scaled = vjp_fn(2.0)[0]
# scaled → [4., 8., 12.] (2*x scaled by cotangent=2.0)
Common pitfalls
- vjp_fn returns a TUPLE: index [0] to get the gradient for the first (and only) input argument.
- Cotangent shape must match OUTPUT shape: for scalar f, pass a Python float; for array f, pass an array of the same shape.
- jax.vjp(f(x), cotangent) is wrong: wrap the FUNCTION, not its application, as in jax.vjp(f, x).
- Reusing vjp_fn: you can call vjp_fn(g) many times with different cotangents without re-running the forward pass.
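The first two pitfalls are easier to see with an array-valued function. In this minimal sketch, square and product are hypothetical helpers introduced only for illustration: the first shows a cotangent shaped like the output, the second shows that the returned tuple has one entry per input argument.

```python
import jax
import jax.numpy as jnp

def square(x):
    # Array-valued function: the output has the same shape as x.
    return x ** 2

x = jnp.array([1.0, 2.0, 3.0])
y, vjp_fn = jax.vjp(square, x)

g = jnp.array([1.0, 0.0, -1.0])  # cotangent shaped like y, not a float
(x_bar,) = vjp_fn(g)             # one-element tuple: one input argument
# x_bar → [2., 0., -6.]  (2 * x * g, elementwise)

def product(a, b):
    # Hypothetical two-argument function.
    return a * b

_, vjp_fn2 = jax.vjp(product, jnp.array(3.0), jnp.array(4.0))
a_bar, b_bar = vjp_fn2(jnp.array(1.0))  # two inputs, so a two-element tuple
# a_bar → 4.0, b_bar → 3.0
```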
Problem
Implement vjp_quadratic(x, cotangent) that applies jax.vjp to
f(x) = sum(x²) at x with the given cotangent.
- x: 1-D jax array.
- cotangent: scalar (Python float).
Returns: 1-D array same shape as x — the vector-Jacobian product
cotangent · 2x.
For cotangent=1.0 this equals jax.grad(f)(x) = 2*x.
Use jax.vjp — do not compute the gradient by hand.