medium
primitives
Jacobian via Batched vjp
Why this matters
For a function f: R^d → R^k, the full Jacobian matrix has shape (k, d).
You can recover it row-by-row using jax.vjp: each call vjp_fn(e_i)
with a one-hot cotangent e_i ∈ R^k yields row i of the Jacobian.
jax.jacrev does essentially the same thing internally, except that it
batches the calls with vmap instead of looping. Understanding the loop
demystifies:
- Why reverse-mode costs k passes (one per output).
- How to implement custom Jacobian accumulation for structured outputs.
- Why vmap(vjp_fn)(eye) is the vectorized, JIT-friendly version.
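The three points above can be sketched side by side. The function f below is a hypothetical R^3 → R^2 example chosen only for illustration, and the variable names are assumptions; the sketch compares the explicit loop, the vmap form, and jax.jacrev on the same input.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical test function R^3 -> R^2 (illustration only).
    return jnp.stack([jnp.sum(x), jnp.sum(x ** 2)])

x = jnp.array([1.0, 2.0, 3.0])

# One forward pass; vjp_fn replays the backward pass for any cotangent.
y, vjp_fn = jax.vjp(f, x)
k = y.shape[0]
eye = jnp.eye(k, dtype=y.dtype)  # row i is the one-hot cotangent e_i

# Loop version: k backward passes, one per output.
J_loop = jnp.stack([vjp_fn(eye[i])[0] for i in range(k)], axis=0)

# Vectorized version: vmap maps vjp_fn over the leading axis of eye.
# vjp_fn still returns a 1-tuple (one entry per input of f).
(J_vmap,) = jax.vmap(vjp_fn)(eye)

# Library reference.
J_ref = jax.jacrev(f)(x)
# All three are [[1., 1., 1.],
#                [2., 4., 6.]]
```

The loop issues k separate backward passes from Python, while the vmap form expresses them as one batched computation, which is why it composes well with jax.jit.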
Worked mini-example
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] + x[1], x[0] * x[1]])

x = jnp.array([2.0, 3.0])
_, vjp_fn = jax.vjp(f, x)  # one forward pass; vjp_fn runs the backward pass
eye = jnp.eye(2)           # rows are the one-hot cotangents e_0, e_1
row0 = vjp_fn(eye[0])[0]   # → [1., 1.]  (∂(x0+x1)/∂x = [1, 1])
row1 = vjp_fn(eye[1])[0]   # → [3., 2.]  (∂(x0*x1)/∂x = [x1, x0] = [3, 2])
J = jnp.stack([row0, row1], axis=0)
# → array([[1., 1.],
#          [3., 2.]])
Common pitfalls
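- Cotangent shape must match output shape: vjp_fn(eye[i]) works because eye[i] has the same shape as f(x) (a length-2 vector in the mini-example above, a length-3 vector in the problem below).
- vjp_fn returns a tuple with one entry per input of f: always index [0] for the single-argument gradient.
- The for-loop is pedagogically clear; vmap(vjp_fn)(eye) is usually faster in practice because it batches the k backward passes into one vectorized computation instead of dispatching them one at a time from Python.
- jax.jacrev(f)(x) is equivalent: use it in production; the loop here is just to show the mechanics.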
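The first two pitfalls matter most when the output is not a flat vector. The following is a minimal sketch, assuming a single array input and a single array output; jacobian_rows and g are hypothetical names introduced here for illustration. It builds each basis cotangent with the same shape as the output and indexes [0] on every call.

```python
import jax
import jax.numpy as jnp

def jacobian_rows(f, x):
    """Loop-based Jacobian for an array-valued f of one array argument."""
    y, vjp_fn = jax.vjp(f, x)
    # One basis cotangent per output element, each reshaped to y.shape
    # so that it matches the output shape exactly.
    basis = jnp.eye(y.size, dtype=y.dtype).reshape((y.size,) + y.shape)
    # vjp_fn returns a 1-tuple here, hence the [0].
    rows = [vjp_fn(basis[i])[0] for i in range(y.size)]
    # Same layout as jax.jacrev: output axes first, then input axes.
    return jnp.stack(rows, axis=0).reshape(y.shape + x.shape)

def g(x):
    # Hypothetical function with a 2x2 matrix output (illustration only).
    return jnp.outer(x, x)

x = jnp.array([2.0, 3.0])
J = jacobian_rows(g, x)  # shape (2, 2, 2); J[i, j, k] = ∂(x_i * x_j)/∂x_k
```

For a length-k vector output the reshape is a no-op and the basis reduces to the rows of jnp.eye(k), which is the case used throughout this page.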
Problem
Implement jacobian_via_vjp(x) that computes the Jacobian of
f(x) = [sum(x), sum(2*x), sum(x²)]
by calling jax.vjp once and then iterating over the 3 standard basis
vectors of R^3.
- x: 1-D jax array of length d.
Returns: 2-D array of shape (3, d):
- Row 0: [1, 1, …, 1] (gradient of sum(x))
- Row 1: [2, 2, …, 2] (gradient of sum(2x))
- Row 2: [2x₀, 2x₁, …, 2x_{d-1}] (gradient of sum(x²))
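The expected rows can be checked independently of any vjp-based solution. The sketch below is not the solution to the exercise; it writes the three rows in closed form and compares them with jax.jacrev, so an implementation of jacobian_via_vjp can be tested against either. The helper name expected_jacobian and the sample input are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x) = [sum(x), sum(2*x), sum(x^2)], as stated in the problem.
    return jnp.stack([jnp.sum(x), jnp.sum(2 * x), jnp.sum(x ** 2)])

def expected_jacobian(x):
    # Closed-form rows listed above: ones, twos, and 2*x.
    return jnp.stack([jnp.ones_like(x), 2 * jnp.ones_like(x), 2 * x], axis=0)

x = jnp.array([1.0, 2.0, 3.0, 4.0])  # sample input with d = 4
J_closed = expected_jacobian(x)
J_ref = jax.jacrev(f)(x)             # library reference, shape (3, 4)
```

A candidate solution could then be validated with jnp.allclose(jacobian_via_vjp(x), J_ref).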
Hints
jax
vjp
jacobian