
Jacobian via Mixed-Mode (jvp+vjp)

Why this matters

When computing the full Jacobian of f: R^d → R^k, the cost depends on which autodiff mode you choose:

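  • jacfwd (forward mode): costs O(d) JVPs, one per input dimension. Best when d ≪ k (a tall Jacobian).
  • jacrev (reverse mode): costs O(k) VJPs, one per output dimension. Best when k ≪ d (a wide Jacobian).
  • Mixed mode: when d ≈ k, neither pure mode has a clear advantage for a dense Jacobian; composing the two modes pays off mainly for second derivatives. For example, jacfwd(jacrev(f)) (forward-over-reverse) computes the Hessian of a scalar-valued f efficiently, and this is how jax.hessian is implemented.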
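The cost figures above come from how each mode assembles the matrix. Forward mode pushes one tangent per input basis vector through jax.jvp and gets one column of J; reverse mode pulls one cotangent per output basis vector back through jax.vjp and gets one row. This is a minimal sketch of that idea, assuming a hypothetical test function f and hypothetical helper names jacobian_via_jvp and jacobian_via_vjp (these are not JAX APIs). It batches the passes with jax.vmap, which mirrors the idea behind jacfwd and jacrev, but it is not their actual source.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example function R^3 -> R^2, for illustration only
    return jnp.array([jnp.sin(x[0]) * x[1], x[1] + x[2] ** 2])

def jacobian_via_jvp(f, x):
    # Forward mode: one JVP per input basis vector e_i yields column J @ e_i.
    basis = jnp.eye(x.shape[0], dtype=x.dtype)            # d basis vectors
    push = lambda v: jax.jvp(f, (x,), (v,))[1]             # tangent output, shape (k,)
    cols = jax.vmap(push)(basis)                            # shape (d, k)
    return cols.T                                           # shape (k, d)

def jacobian_via_vjp(f, x):
    # Reverse mode: one VJP per output basis vector u_j yields row u_j^T @ J.
    y, pullback = jax.vjp(f, x)
    basis = jnp.eye(y.shape[0], dtype=y.dtype)            # k basis vectors
    rows = jax.vmap(lambda u: pullback(u)[0])(basis)        # shape (k, d)
    return rows

x = jnp.array([0.5, 2.0, -1.0])
J_jvp = jacobian_via_jvp(f, x)   # built from d = 3 JVPs
J_vjp = jacobian_via_vjp(f, x)   # built from k = 2 VJPs
```

Here d = 3 and k = 2, so the reverse-mode version needs one fewer pass; both return the same (2, 3) matrix.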

JAX provides both as jax.jacfwd and jax.jacrev. They produce the same Jacobian matrix (up to floating-point rounding); the difference is purely computational cost.

This problem uses jacrev directly for an R^d → R^3 function. Because k = 3 is small, reverse mode needs only three VJPs no matter how large d is. Understanding mode selection is critical when building efficient optimizers, second-order methods, or structured Hessian approximations.
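Second-order methods are where composing the two modes matters most. The sketch below assumes a hypothetical scalar-valued test function g. It builds the Hessian forward-over-reverse and reverse-over-forward, compares both with jax.hessian, and forms a Hessian-vector product by taking a JVP of jax.grad (itself a VJP), which avoids materializing H at all.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical scalar-valued function R^d -> R, for illustration only
    return jnp.sum(x ** 3) + x[0] * x[1]

x = jnp.array([1.0, 2.0, 3.0])

H_fwd_rev = jax.jacfwd(jax.jacrev(g))(x)   # forward-over-reverse
H_rev_fwd = jax.jacrev(jax.jacfwd(g))(x)   # reverse-over-forward, same matrix
H_builtin = jax.hessian(g)(x)              # JAX's built-in Hessian

# Hessian-vector product H @ v: forward-mode JVP of the reverse-mode gradient.
v = jnp.array([1.0, 1.0, 1.0])
hvp = jax.jvp(jax.grad(g), (x,), (v,))[1]
```

For this g, H = diag(6x) plus 1 in the (0, 1) and (1, 0) entries, so at x = [1, 2, 3] the Hessian is [[6, 1, 0], [1, 12, 0], [0, 0, 18]] and H @ v gives the row sums [7, 13, 18].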

Worked mini-example

import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0], x[0] + x[1]])   # R^2 → R^2

x = jnp.array([1.0, 2.0])

J_fwd = jax.jacfwd(f)(x)   # shape (2, 2)
J_rev = jax.jacrev(f)(x)   # shape (2, 2), identical values

# Both → [[1., 0.], [1., 1.]]

Common pitfalls

  • jacfwd and jacrev return the same matrix, so use profiling, not intuition, to pick the faster one in practice.
  • Output shape: for f: R^d → R^k, the Jacobian has shape (k, d). In general, its shape is f(x).shape + x.shape.
  • The function must return an array, not a scalar, for the Jacobian to be 2-D; a scalar output gives a 1-D gradient of shape (d,).
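The shape pitfalls are easy to check directly. In this sketch, which uses three small hypothetical functions, the Jacobian shape is always the output shape followed by the input shape: a scalar output yields a plain gradient, and a matrix input yields a 3-D array.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3.0)  # d = 3

def scalar_out(x):
    return jnp.sum(x ** 2)                   # R^3 -> R

def vector_out(x):
    return jnp.stack([x.sum(), x.prod()])    # R^3 -> R^2

def matrix_in(X):
    return X @ jnp.ones(3)                   # R^(2x3) -> R^2

grad_shape = jax.jacrev(scalar_out)(x).shape                  # (3,)
jac_shape = jax.jacrev(vector_out)(x).shape                   # (2, 3)
tensor_shape = jax.jacrev(matrix_in)(jnp.ones((2, 3))).shape  # (2, 2, 3)
```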

Problem

Implement jacobian_mixed(x) that computes the Jacobian of f(x) = [sum(x), sum(2·x), sum(x²)] using jax.jacrev.

  • x: 1-D jax array of shape (d,).

Returns: 2-D array of shape (3, d), the Jacobian matrix.

Row 0: all ones. Row 1: all twos. Row 2: 2·x.
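One possible shape for a solution, as a minimal sketch that assumes f is defined at module level exactly as stated in the problem. Since f has only k = 3 outputs, jax.jacrev needs three VJPs regardless of d.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f: R^d -> R^3, as stated in the problem
    return jnp.array([jnp.sum(x), jnp.sum(2.0 * x), jnp.sum(x ** 2)])

def jacobian_mixed(x):
    # Reverse mode: one VJP per output row, so 3 passes for any d.
    return jax.jacrev(f)(x)

x = jnp.array([1.0, 2.0, 3.0])
J = jacobian_mixed(x)
# J == [[1., 1., 1.], [2., 2., 2.], [2., 4., 6.]]
```

The rows match the expected pattern: ones, twos, and 2·x.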

Hints

jax jacobian mixed-mode
