jacfwd vs jacrev
Why this matters
JAX provides two Jacobian functions built from opposite AD primitives:
- jax.jacfwd: forward-mode (runs jvp once per input dimension). Cost ∝ input_dim. Efficient when input_dim ≪ output_dim (tall Jacobian).
- jax.jacrev: reverse-mode (runs vjp once per output dimension). Cost ∝ output_dim. Efficient when output_dim ≪ input_dim (wide Jacobian, i.e., the classic loss-gradient case).
Both produce the same matrix; the choice affects only wall-clock time and
memory. For square Jacobians, either works. For the Hessian of a scalar
loss, jax.hessian composes jacfwd(jacrev(f)): reverse mode computes the
gradient in a single pass, and forward mode then differentiates that
R^d → R^d gradient function.
Understanding this trade-off is the first step toward writing efficient second-order optimizers and sensitivity analyses.
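To make the "once per input dimension" and "once per output dimension" descriptions concrete, the sketch below assembles a Jacobian by hand from jax.jvp and jax.vjp. The function f here is a hypothetical R^2 → R^3 toy, not part of the original material, and the explicit Python loops are for illustration only; jacfwd and jacrev batch this work internally rather than looping.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical toy function R^2 -> R^3, chosen only for illustration.
    return jnp.array([x[0] + x[1], x[0] * x[1], x[0] ** 2])

x = jnp.array([2.0, 3.0])

# Forward mode: one JVP per input dimension; each call yields one column.
cols = [jax.jvp(f, (x,), (e,))[1] for e in jnp.eye(x.size)]
J_from_jvp = jnp.stack(cols, axis=1)   # shape (3, 2)

# Reverse mode: one VJP per output dimension; each call yields one row.
y, vjp_fn = jax.vjp(f, x)
rows = [vjp_fn(e)[0] for e in jnp.eye(y.size)]
J_from_vjp = jnp.stack(rows, axis=0)   # shape (3, 2)
```

Here the forward-mode path needs two passes (one per input) and the reverse-mode path needs three (one per output), which is the cost asymmetry described above in miniature.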
Worked mini-example
import jax, jax.numpy as jnp

def f(x):
    return jnp.array([x[0] + x[1], x[0] * x[1]])

x = jnp.array([2.0, 3.0])

J_fwd = jax.jacfwd(f)(x)
# array([[1., 1.],
#        [3., 2.]])  # df0/dx0=1, df0/dx1=1, df1/dx0=x1=3, df1/dx1=x0=2

J_rev = jax.jacrev(f)(x)
# array([[1., 1.],
#        [3., 2.]])  # identical result, different computation path
Both return shape (output_dim, input_dim).
Common pitfalls
- jax.jacfwd(f(x)) is wrong! You wrap the function, not its result. Correct: jax.jacfwd(f)(x).
- Output shape convention: for a function f: R^d → R^m, the Jacobian shape is (m, d), output axes first, then input axes.
- Scalar outputs: if f returns a scalar, use jax.grad instead; jacfwd/jacrev still work, but the result has shape (d,) with no leading output axis.
- jax.hessian is jacfwd(jacrev(f)) by default, so there is no need to compose manually.
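The scalar-output and Hessian pitfalls can be checked directly. The following is a minimal sketch using a hypothetical scalar function g (not from the original page); it compares jax.grad with jax.jacrev on a scalar output and compares jax.hessian with the manual composition.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical scalar-valued toy function R^3 -> R.
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])

grad_g = jax.grad(g)(x)     # shape (3,)
jac_g = jax.jacrev(g)(x)    # also shape (3,): no leading output axis

H_builtin = jax.hessian(g)(x)             # shape (3, 3)
H_manual = jax.jacfwd(jax.jacrev(g))(x)   # the same composition written out
```

For this g the Hessian is diagonal with entries 6x, so both H_builtin and H_manual should agree with jnp.diag(6 * x).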
Problem
Implement jacobian_fwd(x) that computes the Jacobian of
f(x) = [sum(x), sum(x²), sum(x³)]
using jax.jacfwd. The function maps x ∈ R^d → R^3, so the
Jacobian has shape (3, d):
- Row 0: [∂/∂x₀ sum(x), …] = [1, 1, …, 1]
- Row 1: [∂/∂x₀ sum(x²), …] = [2x₀, 2x₁, …]
- Row 2: [∂/∂x₀ sum(x³), …] = [3x₀², 3x₁², …]
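One possible way to approach this, offered as a sketch rather than the reference solution: define f as stated, wrap it with jax.jacfwd, and compare against the analytic rows listed above. The helper name f, the test input, and the variable expected are assumptions introduced for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x) = [sum(x), sum(x^2), sum(x^3)], mapping R^d -> R^3.
    return jnp.array([jnp.sum(x), jnp.sum(x ** 2), jnp.sum(x ** 3)])

def jacobian_fwd(x):
    # Forward-mode Jacobian of f at x; result has shape (3, d).
    return jax.jacfwd(f)(x)

x = jnp.array([1.0, 2.0, 3.0, 4.0])   # assumed test input, d = 4
J = jacobian_fwd(x)

# Analytic rows from the problem statement: [1, ...], [2x_i, ...], [3x_i^2, ...].
expected = jnp.stack([jnp.ones_like(x), 2 * x, 3 * x ** 2])
```

With d = 4 the result has shape (3, 4), which makes the output-axes-first convention easy to see.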