medium primitives

grad vs vjp

Why this matters

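For a function f with a scalar floating-point output, jax.grad(f)(x) computes the same value as jax.vjp(f, x)[1](1.0)[0]: jax.grad is essentially a convenience wrapper around reverse mode with the cotangent fixed at 1.0. Knowing this equivalence: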

  • Demystifies what grad actually computes under the hood.
  • Lets you switch between grad and vjp when you need more control (e.g., non-unit cotangents, multi-output functions).
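  • Prepares you for custom_vjp, where you write the backward rule yourself: the bwd function you register receives the output cotangent and returns a tuple with one cotangent per primal input, the same contract vjp_fn follows.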
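To make the third point concrete, here is a minimal sketch of a custom_vjp version of the loss from the worked example, assuming x = [1.0, 2.0] and b = 0.0 are closed over as constants. The names loss_fwd and loss_bwd are illustrative choices, not names JAX requires. The backward function receives the output cotangent g and returns a 1-tuple, mirroring what vjp_fn returns.

```python
import jax
import jax.numpy as jnp

# Assumed setup, mirroring the worked example: x and b are constants closed over by loss.
x = jnp.array([1.0, 2.0])
b = 0.0

@jax.custom_vjp
def loss(w):
    return jnp.sum((w * x + b) ** 2)

def loss_fwd(w):
    # Forward pass: return the primal output plus residuals saved for the backward pass.
    r = w * x + b
    return jnp.sum(r ** 2), r

def loss_bwd(r, g):
    # Backward pass: g is the output cotangent (the seed). Like vjp_fn's result,
    # the return value is a tuple with one cotangent per primal argument.
    return (g * jnp.sum(2.0 * r * x),)

loss.defvjp(loss_fwd, loss_bwd)

print(jax.grad(loss)(1.0))              # 10.0, produced by loss_bwd with g = 1.0
print(jax.vjp(loss, 1.0)[1](2.0)[0])    # 20.0, loss_bwd scales by g = 2.0
```

Both jax.grad and jax.vjp route through loss_bwd here, which is why the hand-written rule has to follow the same tuple-returning contract.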

The cotangent 1.0 is the “seed” for reverse mode: it tells the backward pass that the output contributes with weight 1, which recovers the plain gradient. Because the backward pass is linear in the cotangent, any other scalar cotangent c scales the gradient by c.
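A quick check of the seed's role, using a hypothetical function f(w) = w**3 chosen only because its gradient 3*w**2 is easy to verify by hand. Since the backward pass is linear in the cotangent, seeding with 3.0 should give three times the seed-1.0 result, and the seed-1.0 result should match jax.grad.

```python
import jax
import jax.numpy as jnp

def f(w):
    # Hypothetical scalar-output function, picked so the numbers are easy to check.
    return w ** 3

w0 = 2.0
_, vjp_fn = jax.vjp(f, w0)

g_unit = vjp_fn(1.0)[0]      # seed 1.0 recovers the plain gradient: 3 * w0**2 = 12.0
g_scaled = vjp_fn(3.0)[0]    # seed 3.0 scales it by 3: 36.0
g_grad = jax.grad(f)(w0)     # jax.grad seeds with 1.0 internally

print(g_unit, g_scaled, g_grad)
print(bool(jnp.allclose(g_scaled, 3.0 * g_unit)))   # True
```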

Worked mini-example

import jax, jax.numpy as jnp

x = jnp.array([1.0, 2.0])
b = 0.0

def loss(w):
    return jnp.sum((w * x + b) ** 2)

# Via jax.grad:
g_grad = jax.grad(loss)(1.0)
# → 10.0   (sum(2*(w*x + b)*x) at w=1, b=0: 2*1*1 + 2*2*2 = 2 + 8 = 10)

# Via jax.vjp:
_, vjp_fn = jax.vjp(loss, 1.0)
g_vjp = vjp_fn(1.0)[0]
# → 10.0   (identical)
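The hand calculation in the comments can be cross-checked against both autodiff routes. This sketch reuses the same x, b, and loss and adds a hypothetical helper, analytic_grad, that applies the chain rule by hand: d/dw sum((w*x + b)²) = sum(2*(w*x + b)*x).

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
b = 0.0

def loss(w):
    return jnp.sum((w * x + b) ** 2)

def analytic_grad(w):
    # Hypothetical helper: chain rule by hand, d/dw sum((w*x + b)^2) = sum(2 * (w*x + b) * x)
    return jnp.sum(2.0 * (w * x + b) * x)

w0 = 1.0
g_grad = jax.grad(loss)(w0)
g_vjp = jax.vjp(loss, w0)[1](1.0)[0]   # one-liner form of the vjp route
g_hand = analytic_grad(w0)

print(g_grad, g_vjp, g_hand)            # all three equal 10.0
```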

Common pitfalls

  • Cotangent must match the output's shape (and dtype): for a scalar-output function, pass a scalar such as 1.0, not a shape-(1,) array like jnp.array([1.0]).
  • vjp_fn returns a tuple with one entry per primal argument: index [0] (or unpack with (g,) = vjp_fn(1.0)) even for a single-argument function.
  • w is a scalar here: jax.vjp(loss, w) treats it as a 0-d array, so the returned gradient is a 0-d JAX array (shape ()), not a Python float.
  • Don't bother with jnp.array(1.0) when 1.0 suffices: JAX accepts a plain Python float as the cotangent for a scalar float output.
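The pitfalls above can be reproduced in a few lines. This sketch assumes b = 0 for brevity and introduces a hypothetical vector-output function, scaled, for which jax.grad would raise an error because the output is not scalar; jax.vjp still works, provided the cotangent has the output's shape. The exact exception type for a mismatched cotangent may vary across JAX versions, so the sketch catches both TypeError and ValueError.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])

def loss(w):                       # scalar output, shape ()
    return jnp.sum((w * x) ** 2)

def scaled(w):                     # hypothetical vector output, shape (2,)
    return w * x

# vjp_fn returns a tuple with one cotangent per primal argument.
_, vjp_loss = jax.vjp(loss, 1.0)
cts = vjp_loss(1.0)
print(type(cts).__name__, len(cts))   # tuple 1
print(cts[0].shape)                   # (): a 0-d JAX array, not a Python float

# For a vector output, the cotangent must be a vector of the same shape.
_, vjp_vec = jax.vjp(scaled, 1.0)
print(vjp_vec(jnp.array([1.0, 0.0]))[0])   # picks out d(w*x[0])/dw = x[0] = 1.0

# A cotangent whose shape does not match the output is rejected.
try:
    vjp_vec(1.0)                      # scalar seed for a shape-(2,) output
except (TypeError, ValueError) as err:
    print("rejected:", type(err).__name__)
```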

Problem

Implement grad_via_vjp(w, b, x) that computes d/dw [sum((w*x + b)²)] using jax.vjp instead of jax.grad.

  • w: Python float, the weight scalar.
  • b: Python float, the bias scalar.
  • x: 1-D JAX array.

Returns: scalar, the gradient of the loss w.r.t. w.

The result must match jax.grad(loss)(w) exactly, where loss(w) = jnp.sum((w * x + b) ** 2).
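One possible sketch of grad_via_vjp, assuming x and b only need to be treated as constants, so the loss closes over them and is differentiated with respect to w alone. The inner function name loss is a local choice, not part of the problem statement.

```python
import jax
import jax.numpy as jnp

def grad_via_vjp(w, b, x):
    """Gradient of sum((w*x + b)**2) w.r.t. w, computed with jax.vjp (sketch)."""
    def loss(w_):
        # x and b are closed over as constants; only w_ is differentiated.
        return jnp.sum((w_ * x + b) ** 2)

    _, vjp_fn = jax.vjp(loss, w)   # forward pass, plus the pullback closure
    (g_w,) = vjp_fn(1.0)           # seed the scalar output with 1.0; unpack the 1-tuple
    return g_w

print(grad_via_vjp(1.0, 0.0, jnp.array([1.0, 2.0])))   # 10.0, as in the worked example
```

Unpacking with (g_w,) = vjp_fn(1.0) makes the 1-tuple explicit; vjp_fn(1.0)[0] is equivalent.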

Hints

jax grad vjp
