medium primitives

Jacobian via Batched vjp

Why this matters

For a function f: R^d → R^k, the full Jacobian matrix has shape (k, d). You can recover it row-by-row using jax.vjp: each call vjp_fn(e_i) with a one-hot cotangent e_i ∈ R^k yields row i of the Jacobian.
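Written out, the pullback is a row-vector-times-matrix product, which is why a one-hot cotangent picks out exactly one row. Here vjp_x stands for the function returned by jax.vjp, ignoring the tuple it wraps its result in:

```latex
J_f(x) \in \mathbb{R}^{k \times d}, \qquad
\mathrm{vjp}_x(u) = u^\top J_f(x) \in \mathbb{R}^{1 \times d}, \qquad
\mathrm{vjp}_x(e_i) = e_i^\top J_f(x) = \bigl[J_f(x)\bigr]_{i,:}
```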

This is essentially what jax.jacrev does internally, except that it vmaps the pullback over all basis vectors at once instead of looping. Working through the loop by hand demystifies:

  • Why building a full Jacobian in reverse mode costs k backward passes (one per output) on top of a single forward pass.
  • How to implement custom Jacobian accumulation for structured outputs.
  • Why vmap(vjp_fn)(eye) is the vectorized, JIT-friendly version.
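The vectorized version can be sketched as a small generic helper. The name jacobian_via_vmap_vjp and the example function g are hypothetical, chosen only for illustration. The helper runs one forward pass, builds the identity matrix as a batch of one-hot cotangents, and vmaps the pullback over its rows. The whole thing can be wrapped in jax.jit, with f marked static.

```python
import jax
import jax.numpy as jnp


def jacobian_via_vmap_vjp(f, x):
    """Jacobian of f: R^d -> R^k via a vmapped VJP (hypothetical helper name)."""
    y, vjp_fn = jax.vjp(f, x)                    # one forward pass; vjp_fn is the pullback
    basis = jnp.eye(y.shape[0], dtype=y.dtype)   # row i is the one-hot cotangent e_i
    (J,) = jax.vmap(vjp_fn)(basis)               # vjp_fn returns a 1-tuple; batch over rows
    return J                                     # shape (k, d)


def g(x):
    # Example function R^3 -> R^2, assumed for illustration only.
    return jnp.stack([jnp.sin(x[0]) * x[1], x[1] + x[2] ** 2])


x = jnp.array([0.5, 2.0, -1.0])
J = jax.jit(jacobian_via_vmap_vjp, static_argnums=0)(g, x)
print(J.shape)                             # (2, 3)
print(jnp.allclose(J, jax.jacrev(g)(x)))   # True
```

Marking f as static is needed because jax.jit can only trace array arguments. Each distinct Python function therefore triggers its own compilation.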

Worked mini-example

import jax, jax.numpy as jnp

def f(x):
    return jnp.array([x[0] + x[1], x[0] * x[1]])

x = jnp.array([2.0, 3.0])
_, vjp_fn = jax.vjp(f, x)

eye = jnp.eye(2)
row0 = vjp_fn(eye[0])[0]   # → [1., 1.]   (∂(x0+x1)/∂x = [1,1])
row1 = vjp_fn(eye[1])[0]   # → [3., 2.]   (∂(x0*x1)/∂x = [x1,x0]=[3,2])

J = jnp.stack([row0, row1], axis=0)
# → array([[1., 1.],
#           [3., 2.]])

Common pitfalls

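  • The cotangent shape must match the output shape: vjp_fn(eye[i]) works because eye[i] has the same shape as f(x) (a length-2 vector in the example above).
  • vjp_fn returns a tuple with one entry per primal argument, so index [0] to get the gradient when f takes a single argument.
  • A Python for-loop is pedagogically clear, but vmap(vjp_fn)(eye) is usually faster in practice. It evaluates all k cotangents as one batched backward pass instead of k separately dispatched calls. It also returns a tuple, so take [0] there too.
  • jax.jacrev(f)(x) returns the same matrix. Use it in production, since the loop here only illustrates the mechanics.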
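The tuple pitfall is easiest to see with a function of two arguments. The function h below is a hypothetical example. The pullback returns one cotangent per primal argument. jax.jacrev with argnums=(0, 1) likewise returns one Jacobian per argument.

```python
import jax
import jax.numpy as jnp


def h(a, b):
    # Hypothetical two-argument function: elementwise product.
    return a * b


a = jnp.array([1.0, 2.0])
b = jnp.array([3.0, 4.0])
y, vjp_fn = jax.vjp(h, a, b)

grads = vjp_fn(jnp.ones_like(y))   # tuple: (cotangent w.r.t. a, cotangent w.r.t. b)
grad_a, grad_b = grads
print(len(grads))   # 2
print(grad_a)       # [3. 4.]  (d(a*b)/da = b)
print(grad_b)       # [1. 2.]  (d(a*b)/db = a)

# One (2, 2) Jacobian per argument: diag(b) and diag(a).
Ja, Jb = jax.jacrev(h, argnums=(0, 1))(a, b)
```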

Problem

Implement jacobian_via_vjp(x) that computes the Jacobian of

f(x) = [sum(x), sum(2*x), sum(x²)]

by calling jax.vjp once and then iterating over the 3 standard basis vectors of R^3.

  • x: 1-D jax array of length d.

Returns: 2-D array of shape (3, d):

  • Row 0: [1, 1, …, 1] (gradient of sum(x))
  • Row 1: [2, 2, …, 2] (gradient of sum(2x))
  • Row 2: [2x₀, 2x₁, …, 2x_{d-1}] (gradient of sum(x²))
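One possible approach, sketched here as an illustration rather than the official solution, follows the worked example directly. It calls jax.vjp once, loops over the three one-hot cotangents in R^3, and stacks the rows. jnp.stack is used to assemble f's output from the three scalar sums.

```python
import jax
import jax.numpy as jnp


def f(x):
    # f: R^d -> R^3, as given in the problem statement.
    return jnp.stack([jnp.sum(x), jnp.sum(2 * x), jnp.sum(x ** 2)])


def jacobian_via_vjp(x):
    y, vjp_fn = jax.vjp(f, x)                    # single forward pass, reused for every row
    basis = jnp.eye(y.shape[0], dtype=y.dtype)   # the 3 standard basis vectors of R^3
    rows = [vjp_fn(e)[0] for e in basis]         # vjp_fn(e_i)[0] is row i of the Jacobian
    return jnp.stack(rows, axis=0)               # shape (3, d)


x = jnp.array([1.0, -2.0, 0.5, 3.0])
J = jacobian_via_vjp(x)
print(J.shape)  # (3, 4)
print(J)
```

Row 0 should be all ones, row 1 all twos, and row 2 should equal 2 * x. Comparing against jax.jacrev(f)(x) is a quick sanity check.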

Hints

jax vjp jacobian
