medium
primitives
grad vs vjp
Why this matters
jax.grad(f)(x) is just sugar for jax.vjp(f, x)[1](1.0)[0] when f
has scalar output. Knowing this equivalence:
- Demystifies what grad actually computes under the hood.
- Lets you switch between grad and vjp when you need more control (e.g., non-unit cotangents, multi-output functions).
- Prepares you for custom_vjp, where you define the backward pass directly in terms of vjp_fn.
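To make the custom_vjp point concrete, here is a minimal sketch of a hand-written backward pass. The function square and its helpers square_fwd and square_bwd are hypothetical names chosen for illustration; they are not part of the exercise.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function (not from the exercise): f(x) = x ** 2
@jax.custom_vjp
def square(x):
    return x ** 2

def square_fwd(x):
    # Forward pass: return the output plus the residuals saved for backward.
    return x ** 2, x

def square_bwd(residual, cotangent):
    x = residual
    # Backward pass: one cotangent per input, returned as a tuple.
    return (2.0 * x * cotangent,)

square.defvjp(square_fwd, square_bwd)

# jax.grad seeds the backward pass with a cotangent of 1.
g = jax.grad(square)(3.0)

# jax.vjp lets you choose the cotangent yourself.
_, vjp_fn = jax.vjp(square, 3.0)
g_scaled = vjp_fn(0.5)[0]
```

The backward function receives the same cotangent that you would pass to vjp_fn, which is why being comfortable with jax.vjp makes custom_vjp easier to read.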
The cotangent 1.0 is the "seed" for reverse-mode: it tells the backward
pass "the output contributes with weight 1", which recovers the plain
gradient. Any other scalar cotangent scales the gradient by that factor.
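The scaling claim can be checked with a short sketch. It assumes a loss of the same form as the exercise with the bias fixed to zero; the variable names plain and scaled are illustrative.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])

def loss(w):
    # Same form as the exercise loss, with the bias fixed to 0.
    return jnp.sum((w * x) ** 2)

out, vjp_fn = jax.vjp(loss, 1.0)

plain = vjp_fn(1.0)[0]    # seed 1.0: the plain gradient
scaled = vjp_fn(3.0)[0]   # seed 3.0: three times the plain gradient
```

Because the backward pass is linear in the cotangent, scaled equals 3 * plain.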
Worked mini-example
import jax, jax.numpy as jnp
x = jnp.array([1.0, 2.0])
b = 0.0
def loss(w):
    return jnp.sum((w * x + b) ** 2)
# Via jax.grad:
g_grad = jax.grad(loss)(1.0)
# → 10.0 (2*(1*1+2*2)*1 = 2*(1+4) = 10)
# Via jax.vjp:
_, vjp_fn = jax.vjp(loss, 1.0)
g_vjp = vjp_fn(1.0)[0]
# → 10.0 (identical)
Common pitfalls
- Cotangent must match output shape: for a scalar-output function, pass a scalar such as the Python float 1.0 rather than a non-scalar array such as jnp.array([1.0]).
- vjp_fn returns a tuple: index [0] even for a single-argument function.
- w is a scalar here: jax.vjp(loss, w) treats the scalar as a 0-d array, so the returned gradient is a 0-d JAX array rather than a Python float.
- Don't pass jnp.array(1.0) when 1.0 suffices: JAX accepts a plain Python float as the cotangent for scalar outputs.
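The first three pitfalls can be seen together in a small sketch with a vector-valued function of two arguments. The function residuals is a hypothetical helper introduced for illustration and is not part of the exercise.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])

def residuals(w, b):
    # Vector-valued output with shape (2,), unlike the scalar loss above.
    return w * x + b

out, vjp_fn = jax.vjp(residuals, 1.0, 0.5)

# The cotangent must have the output's shape, so build it from the output.
ct = jnp.ones_like(out)

# vjp_fn returns a tuple with one entry per primal input.
grads = vjp_fn(ct)
g_w, g_b = grads
```

Here g_w and g_b are both 0-d arrays, matching the scalar inputs w and b, while the cotangent ct has the output's shape (2,).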
Problem
Implement grad_via_vjp(w, b, x) that computes d/dw [sum((w*x + b)²)]
using jax.vjp instead of jax.grad.
- w: Python float, the weight scalar.
- b: Python float, the bias scalar.
- x: 1-D jax array.
Returns: scalar, the gradient of the loss w.r.t. w.
The result must match jax.grad(loss)(w) exactly.
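One way to sanity-check an implementation is to compare it against jax.grad and against the closed-form derivative, 2 * sum((w*x + b) * x). The sketch below is a possible test harness, not the reference solution; loss_fn, analytic_grad, and check are hypothetical names, and the closed form stands in for the candidate function.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, b, x):
    return jnp.sum((w * x + b) ** 2)

def analytic_grad(w, b, x):
    # Chain rule: d/dw sum((w*x + b)^2) = 2 * sum((w*x + b) * x)
    return 2.0 * jnp.sum((w * x + b) * x)

def check(candidate, w, b, x):
    """Compare a candidate gradient function against jax.grad w.r.t. w."""
    reference = jax.grad(loss_fn, argnums=0)(w, b, x)
    return bool(jnp.allclose(candidate(w, b, x), reference))

x = jnp.array([1.0, 2.0, 3.0])
ok = check(analytic_grad, 0.5, -1.0, x)
```

To test your own function, pass it to check in place of analytic_grad. The harness uses jnp.allclose because a closed form may differ from autodiff by floating-point rounding; a solution built on jax.vjp should agree with jax.grad without any tolerance.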
jax
grad
vjp