medium primitives

jacfwd vs jacrev

Why this matters

JAX provides two Jacobian functions built from opposite AD primitives:

  • jax.jacfwd — forward-mode (runs jvp once per input dimension). Cost ∝ input_dim. Efficient when input_dim ≪ output_dim (tall Jacobian).
  • jax.jacrev — reverse-mode (runs vjp once per output dimension). Cost ∝ output_dim. Efficient when output_dim ≪ input_dim (wide Jacobian, i.e., the classic loss-gradient case).
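The cost model is easier to see if you build the Jacobian by hand from the two primitives. Forward mode pushes each input basis vector through jax.jvp, which gives one column per pass. Reverse mode pulls each output basis vector back through jax.vjp, which gives one row per pass. This is a minimal sketch: the function f and the point x are hypothetical. Batching the passes with jax.vmap roughly mirrors what jacfwd and jacrev do internally.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example map R^3 -> R^2.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
d = x.shape[0]          # input_dim
m = f(x).shape[0]       # output_dim

# Forward mode: one jvp per input basis vector e_j gives column j of J.
# vmap batches the d passes; out_axes=1 stacks the columns side by side.
push_column = lambda v: jax.jvp(f, (x,), (v,))[1]
J_cols = jax.vmap(push_column, out_axes=1)(jnp.eye(d))   # shape (m, d)

# Reverse mode: one vjp per output basis vector e_i gives row i of J.
# vjp returns a tuple of cotangents (one per primal input), hence [0].
_, f_vjp = jax.vjp(f, x)
J_rows = jax.vmap(f_vjp)(jnp.eye(m))[0]                  # shape (m, d)

print(J_cols.shape, J_rows.shape)  # (2, 3) (2, 3)
```

This Jacobian is wide, so forward mode needs d = 3 passes while reverse mode needs only m = 2.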

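Both produce the same matrix (up to floating-point rounding); the choice affects only wall-clock time and memory. For square Jacobians, both need the same number of passes. Forward mode usually has a slight edge there because it does not have to store intermediate values for a backward pass. For the Hessian of a scalar loss, jax.hessian is implemented as jacfwd(jacrev(f)), i.e. forward-over-reverse. The inner jacrev computes the gradient (a map R^d → R^d) cheaply, and the outer jacfwd then differentiates that square map, where forward mode is the more efficient choice.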
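You can check the Hessian claim on a small hypothetical loss whose second derivatives are easy to write down by hand. The sketch below compares jax.hessian with the explicit forward-over-reverse composition and with the reverse-over-forward alternative. All three agree in value; only their cost differs.

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Hypothetical scalar loss R^3 -> R: w0^3 + w1^3 + w2^3 + w0*w1.
    return jnp.sum(w ** 3) + w[0] * w[1]

w = jnp.array([1.0, 2.0, 3.0])

H = jax.hessian(loss)(w)                    # library helper
H_fr = jax.jacfwd(jax.jacrev(loss))(w)      # forward-over-reverse, written out
H_rf = jax.jacrev(jax.jacfwd(loss))(w)      # reverse-over-forward, same values

# Analytic Hessian: 6*w_i on the diagonal, plus 1 at (0, 1) and (1, 0).
H_exact = jnp.diag(6.0 * w).at[0, 1].set(1.0).at[1, 0].set(1.0)
print(H.shape)  # (3, 3)
```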

Understanding this trade-off is the first step toward writing efficient second-order optimizers and sensitivity analyses.

Worked mini-example

import jax, jax.numpy as jnp

def f(x):
    return jnp.array([x[0] + x[1], x[0] * x[1]])

x = jnp.array([2.0, 3.0])

J_fwd = jax.jacfwd(f)(x)
# Array([[1., 1.],
#        [3., 2.]], dtype=float32)
# df0/dx0 = 1, df0/dx1 = 1, df1/dx0 = x1 = 3, df1/dx1 = x0 = 2

J_rev = jax.jacrev(f)(x)
# Array([[1., 1.],
#        [3., 2.]], dtype=float32)   # identical result, different computation path

Both return shape (output_dim, input_dim).
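The shape convention is easier to see on a non-square map. Below is a minimal sketch with a hypothetical g: R^4 → R^2. Because this Jacobian is wide (2 outputs, 4 inputs), jacrev is the natural choice: it needs 2 reverse passes instead of 4 forward ones.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical map R^4 -> R^2: (sum of entries, product of the first two).
    return jnp.stack([jnp.sum(x), x[0] * x[1]])

x = jnp.arange(1.0, 5.0)    # [1., 2., 3., 4.]
J = jax.jacrev(g)(x)
print(J.shape)              # (2, 4): output_dim first, then input_dim
```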

Common pitfalls

  • jax.jacfwd(f(x)) — wrong! You wrap the function, not its result. Correct: jax.jacfwd(f)(x).
  • Output shape convention: for a function f: R^d → R^m, the Jacobian shape is (m, d) — output axes first, then input axes.
  • Scalar outputs: if f returns a scalar, prefer jax.grad. jacfwd and jacrev still work, but a scalar output contributes no axes, so the result has shape (d,), the same as the gradient.
  • jax.hessian is always jacfwd(jacrev(f)), so there is no need to compose them manually.
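The sketch below illustrates two of these shape rules with hypothetical functions. The first is a scalar-valued s whose jacrev matches jax.grad. The second is a matrix-input h whose Jacobian carries the output axis followed by both input axes.

```python
import jax
import jax.numpy as jnp

def s(x):
    # Hypothetical scalar-valued function R^3 -> R.
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
g_grad = jax.grad(s)(x)
g_rev = jax.jacrev(s)(x)
print(g_grad.shape, g_rev.shape)   # (3,) (3,): a scalar output adds no leading axis

def h(W):
    # Hypothetical map from a (2, 3) matrix to a length-2 vector of row sums.
    return W.sum(axis=1)

W = jnp.ones((2, 3))
J_h = jax.jacfwd(h)(W)
print(J_h.shape)                   # (2, 2, 3): output axis, then input axes
```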

Problem

Implement jacobian_fwd(x) that computes the Jacobian of

f(x) = [sum(x), sum(x²), sum(x³)]

using jax.jacfwd. The function maps x ∈ R^d → R^3, so the Jacobian has shape (3, d):

  • Row 0: [∂/∂x₀ sum(x), …] = [1, 1, …, 1]
  • Row 1: [∂/∂x₀ sum(x²), …] = [2x₀, 2x₁, …]
  • Row 2: [∂/∂x₀ sum(x³), …] = [3x₀², 3x₁², …]
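Here is one possible sketch, assuming f is written as a plain jnp function. The helper f below is our own; the problem only names jacobian_fwd. The sketch checks the result against the three row formulas above.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical helper: f: R^d -> R^3, stacking the three power sums.
    return jnp.stack([jnp.sum(x), jnp.sum(x ** 2), jnp.sum(x ** 3)])

def jacobian_fwd(x):
    # Forward mode: one jvp per input dimension, so cost grows with d.
    return jax.jacfwd(f)(x)

x = jnp.array([1.0, 2.0, 3.0, 4.0])
J = jacobian_fwd(x)
print(J.shape)   # (3, 4)

# Rows from the problem statement: [1, ...], [2x_i, ...], [3x_i^2, ...].
J_expected = jnp.stack([jnp.ones_like(x), 2.0 * x, 3.0 * x ** 2])
```

For d > 3 this Jacobian is wide, so jax.jacrev would need fewer passes (3 instead of d). The problem asks for jacfwd, and both give the same matrix.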

Hints

jax jacfwd jacrev jacobian
