Jacobian via Mixed-Mode (jvp+vjp)
Why this matters
When computing the full Jacobian of f: R^d → R^k, the cost depends on
which mode you choose:
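- jacfwd (forward mode): costs O(d) JVPs, one per input dimension. Best when d ≪ k.
- jacrev (reverse mode): costs O(k) VJPs, one per output dimension. Best when k ≪ d.
- Mixed mode: for d ≈ k, combining both modes can be more efficient than either alone. Composition is also how second derivatives are built: jax.hessian computes jacfwd(jacrev(f)) (forward-over-reverse), which is typically the most efficient way to get the Hessian of a scalar-valued function.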
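The cost figures above follow from how each mode assembles the matrix. Below is a minimal sketch, assuming a hypothetical example function f and illustrative helper names jacobian_via_jvp and jacobian_via_vjp (neither is part of JAX). Forward mode pushes each of the d input basis vectors through jax.jvp and gets one column per pass. Reverse mode pulls each of the k output basis vectors back through jax.vjp and gets one row per pass.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example map R^3 -> R^2, for illustration only
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

def jacobian_via_jvp(f, x):
    # Forward mode: one JVP per input dimension (d passes).
    # Each JVP with basis vector e_i yields column i of J.
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    jvp_col = lambda v: jax.jvp(f, (x,), (v,))[1]
    cols = jax.vmap(jvp_col)(basis)  # shape (d, k): row i holds column i of J
    return cols.T                    # shape (k, d)

def jacobian_via_vjp(f, x):
    # Reverse mode: one VJP per output dimension (k passes).
    # Each VJP with output basis vector u_j yields row j of J.
    y, vjp_fn = jax.vjp(f, x)
    basis = jnp.eye(y.shape[0], dtype=y.dtype)
    rows = jax.vmap(lambda u: vjp_fn(u)[0])(basis)  # shape (k, d)
    return rows

x = jnp.array([1.0, 2.0, 3.0])
J_f = jacobian_via_jvp(f, x)
J_r = jacobian_via_vjp(f, x)
print(J_f.shape)  # (2, 3)
print(bool(jnp.allclose(J_f, J_r)))  # True
```

Here jax.vmap batches the d (or k) passes. This roughly mirrors what jacfwd and jacrev do internally, although the library versions also handle pytree inputs and outputs.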
JAX provides both via jax.jacfwd and jax.jacrev. They produce the same
Jacobian matrix (up to floating-point rounding); the difference is purely computational cost.
This problem uses jacrev directly for an R^d → R^3 function. Understanding
mode selection is critical when building efficient optimizers, second-order
methods, or structured Hessian approximations.
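For the second-order methods mentioned above, the mixed-mode composition can be sketched as follows. Assumptions: g is a hypothetical scalar function chosen so its Hessian is easy to check by hand, and hvp is an illustrative helper, not a JAX API. The Hessian-vector product uses a single JVP of the gradient, so it never materializes the d × d matrix.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical scalar-valued function R^d -> R, for illustration only
    return jnp.sum(x ** 3) + x[0] * x[1]

x = jnp.array([1.0, 2.0, 3.0])

# Forward-over-reverse: jacrev gives the gradient (a map R^d -> R^d),
# then jacfwd differentiates that gradient with d JVPs.
H_mixed = jax.jacfwd(jax.jacrev(g))(x)

# jax.hessian composes the two modes the same way.
H_builtin = jax.hessian(g)(x)

def hvp(g, x, v):
    # Hypothetical helper: Hessian-vector product as a JVP of the gradient.
    return jax.jvp(jax.grad(g), (x,), (v,))[1]

v = jnp.array([1.0, 0.0, 0.0])
print(H_mixed.shape)                            # (3, 3)
print(bool(jnp.allclose(H_mixed, H_builtin)))   # True
print(bool(jnp.allclose(hvp(g, x, v), H_mixed[:, 0])))  # True
```

By hand, H = [[6, 1, 0], [1, 12, 0], [0, 0, 18]] at x = [1, 2, 3], so the product with e_0 is the first column [6, 1, 0].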
Worked mini-example
```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0], x[0] + x[1]])  # R^2 → R^2

x = jnp.array([1.0, 2.0])
J_fwd = jax.jacfwd(f)(x)  # shape (2, 2)
J_rev = jax.jacrev(f)(x)  # shape (2, 2), identical values
# Both → [[1., 0.], [1., 1.]]
```
Common pitfalls
- jacfwd and jacrev return the same matrix, so use profiling, not intuition, to pick the faster one in practice.
- Output shape: for f: R^d → R^k, the Jacobian has shape (k, d).
- The inner function must return an array, not a scalar, for the Jacobian to be 2-D; a scalar output gives a 1-D gradient of shape (d,).
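A quick shape check for the pitfalls above, using two hypothetical functions: vec_out maps R^4 to R^2, and scalar_out maps R^4 to a scalar.

```python
import jax
import jax.numpy as jnp

d = 4
x = jnp.arange(1.0, d + 1.0)  # [1., 2., 3., 4.], shape (d,)

def vec_out(x):
    # Hypothetical R^d -> R^k with k = 2
    return jnp.stack([jnp.sum(x), jnp.prod(x)])

def scalar_out(x):
    # Hypothetical R^d -> R (scalar output)
    return jnp.sum(x ** 2)

J_rev = jax.jacrev(vec_out)(x)
J_fwd = jax.jacfwd(vec_out)(x)
g_rev = jax.jacrev(scalar_out)(x)

print(J_rev.shape)  # (2, 4): (k, d)
print(J_fwd.shape)  # (2, 4): same shape as jacrev
print(g_rev.shape)  # (4,): a 1-D gradient, not (1, 4)
```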
Problem
Implement jacobian_mixed(x) that computes the Jacobian of
f(x) = [sum(x), sum(2·x), sum(x²)] using jax.jacrev.
- x: 1-D JAX array of shape (d,).
- Returns: 2-D array of shape (3, d), the Jacobian matrix. Row 0: all ones. Row 1: all twos. Row 2: 2·x.
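One possible sketch of jacobian_mixed, assuming a module-level helper f (the name is illustrative) that stacks the three sums into a length-3 vector. This is not the site's reference solution. Because f has only k = 3 outputs, jax.jacrev needs 3 VJPs however large d is, which is why reverse mode fits here.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x) = [sum(x), sum(2*x), sum(x**2)], mapping R^d -> R^3
    return jnp.stack([jnp.sum(x), jnp.sum(2.0 * x), jnp.sum(x ** 2)])

def jacobian_mixed(x):
    # Reverse mode: one VJP per output row (3 in total).
    # x must be a floating-point array for differentiation.
    return jax.jacrev(f)(x)

x = jnp.array([1.0, -2.0, 0.5, 3.0])
J = jacobian_mixed(x)
print(J.shape)  # (3, 4)
```

For this x, row 2 equals 2·x = [2, -4, 1, 6], matching the expected pattern.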