
jvp with Structured Tangents

Why this matters

jax.jvp computes the forward-mode derivative, the Jacobian-vector product (JVP) J(x)·v, together with the primal output f(x) in a single forward pass. For multi-argument functions, primals and tangents are tuples with one entry per positional argument, in the same order as the function's parameters.
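As a quick check of the definition, the sketch below compares jax.jvp against an explicitly materialized Jacobian for a toy two-argument function h (an illustrative assumption, not part of the problem). With several arguments, the tangent output is the sum of each argument's Jacobian block applied to its own tangent, J_a·v_a + J_b·v_b. jax.jacfwd with argnums=(0, 1) returns those blocks as a tuple.

```python
import jax
import jax.numpy as jnp

def h(a, b):
    # Toy elementwise function of two arguments (illustrative only)
    return jnp.sin(a) * b

a = jnp.array([0.5, 1.0])
b = jnp.array([2.0, 3.0])
va = jnp.array([1.0, 0.0])  # tangent for a
vb = jnp.array([0.0, 1.0])  # tangent for b

# One forward pass: primal output plus the directional derivative
y, jv = jax.jvp(h, (a, b), (va, vb))

# Reference only: build both Jacobian blocks explicitly and combine them
Ja, Jb = jax.jacfwd(h, argnums=(0, 1))(a, b)
expected = Ja @ va + Jb @ vb
```

jax.jvp itself never builds J; the jacfwd call is there purely as a reference for comparison.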

This is crucial for:

  • Directional derivatives in physics simulations.
  • Efficient Hessian-vector products (forward-over-reverse: jax.jvp applied to jax.grad(f)), computed without materializing the Hessian.
  • Tangent propagation through structured inputs (e.g., pytrees of params from multiple model components).
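The last two uses can be sketched as follows. The objective loss, the parameter dict, and its component names ("encoder", "head") are hypothetical, chosen only to illustrate forward-over-reverse Hessian-vector products and a tangent pytree that mirrors the structure of the primal parameters.

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Toy scalar objective: gradient is w**3, Hessian is diag(3 * w**2)
    return jnp.sum(w ** 4) / 4.0

def hvp(f, x, v):
    # Forward-over-reverse: push tangent v forward through the gradient map
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

w = jnp.array([1.0, 2.0, 3.0])
v = jnp.ones(3)
hv = hvp(loss, w, v)  # diag(3 * w**2) @ v = [3., 12., 27.]

# Structured tangents: hypothetical params for two model components
params = {"encoder": {"w": jnp.array([1.0, 2.0])}, "head": {"b": jnp.array(0.5)}}

def model_out(p):
    return jnp.sum(p["encoder"]["w"] ** 2) + 3.0 * p["head"]["b"]

# Tangent pytree with the same structure; only the encoder weights move
tangent = {
    "encoder": {"w": jnp.array([1.0, 0.0])},
    "head": {"b": jnp.zeros_like(params["head"]["b"])},
}
out, dout = jax.jvp(model_out, (params,), (tangent,))  # 6.5, 2.0
```

Building the tangent as an explicit dict (rather than mutating a copy) makes it easy to see that its tree structure, shapes, and dtypes match params leaf for leaf.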

Worked mini-example

import jax
import jax.numpy as jnp

def f(a, b):
    return jnp.sum(a * b)

primals = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
tangents = (jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0]))
primal_out, tangent_out = jax.jvp(f, primals, tangents)
# primal_out = 11.0  (1*3 + 2*4)
# tangent_out = 5.0  (product rule, with (da, db) = tangents:
#                     sum(da * b + a * db) = (1*3 + 0*4) + (1*0 + 2*1))
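Because a JVP is linear in the tangents, the 5.0 above splits into one contribution per argument. This sketch recovers each piece by zeroing the other argument's tangent with jnp.zeros_like, the same trick the pitfalls below recommend in place of None.

```python
import jax
import jax.numpy as jnp

def f(a, b):
    return jnp.sum(a * b)

a, b = jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])
va, vb = jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])

# Perturb one argument at a time by giving the other a zero tangent
_, da = jax.jvp(f, (a, b), (va, jnp.zeros_like(b)))  # sum(va * b) = 3.0
_, db = jax.jvp(f, (a, b), (jnp.zeros_like(a), vb))  # sum(a * vb) = 2.0
_, total = jax.jvp(f, (a, b), (va, vb))              # 5.0

# Linearity in the tangents: the per-argument pieces add up to the total
assert jnp.allclose(total, da + db)
```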

Common pitfalls

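  • Tangent shapes and dtypes must match the primals exactly: jax.jvp checks each primal/tangent pair when it is called and raises an error on any mismatch.
  • None is not accepted as a tangent: it is an empty pytree, so the tangents' tree structure no longer matches the primals'. To leave an argument unperturbed, pass a zero array such as jnp.zeros_like(x) instead.
  • Tuple wrapping is mandatory: even for a single argument, primals and tangents must be tuples (or lists), so write (x,), not x.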
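A small sketch that triggers each of the three pitfalls above and then shows the zero-tangent fix. The helper try_jvp is hypothetical and only reports whether jax.jvp rejected its inputs; the exact exception classes and messages may vary between JAX versions, so it catches both TypeError and ValueError.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

def try_jvp(primals, tangents):
    # Hypothetical helper: exception class name if jax.jvp rejects the inputs, else None
    try:
        jax.jvp(f, primals, tangents)
        return None
    except (TypeError, ValueError) as e:
        return type(e).__name__

bad_shape = try_jvp((x,), (jnp.ones(2),))    # tangent shape (2,) vs primal shape (3,)
none_tangent = try_jvp((x,), (None,))        # None is an empty pytree: structure mismatch
bare_array = try_jvp(x, jnp.ones_like(x))    # primals/tangents not wrapped in a tuple

# Fix for "no perturbation in this direction": an explicit zero array
_, zero_dir = jax.jvp(f, (x,), (jnp.zeros_like(x),))  # 0.0
_, ones_dir = jax.jvp(f, (x,), (jnp.ones_like(x),))   # 2 * sum(x) = 12.0
```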

Problem

Implement jvp_split_tangent(x_first, x_second, v_first, v_second) that:

  1. Defines f(x_first, x_second) = sum(x_first ** 2 + 2 * x_second).
  2. Computes primal_out, tangent_out = jax.jvp(f, (x_first, x_second), (v_first, v_second)).
  3. Returns jnp.array([primal_out, tangent_out]) — a 1-D array of shape (2,).
Arguments:

  • x_first, x_second: 1-D JAX arrays of the same shape.
  • v_first, v_second: tangent arrays with the same shapes and dtypes as x_first and x_second, respectively.

Returns: a 1-D array of shape (2,) containing [primal, tangent].

Analytic tangent: 2 * dot(x_first, v_first) + 2 * sum(v_second).
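A minimal sketch of one way to meet this spec, written directly from steps 1 to 3 (not necessarily the site's reference solution). Differentiating term by term, the derivative of x_first ** 2 is 2 * x_first and that of 2 * x_second is 2, which yields the analytic tangent above; the primal is sum(x_first ** 2) + 2 * sum(x_second). The sample inputs are arbitrary values chosen only for the check.

```python
import jax
import jax.numpy as jnp

def jvp_split_tangent(x_first, x_second, v_first, v_second):
    def f(x_first, x_second):
        return jnp.sum(x_first ** 2 + 2 * x_second)

    # One tangent per positional argument, in the same order as the primals
    primal_out, tangent_out = jax.jvp(f, (x_first, x_second), (v_first, v_second))
    return jnp.array([primal_out, tangent_out])

x1 = jnp.array([1.0, 2.0, 3.0])
x2 = jnp.array([0.5, -1.0, 2.0])
v1 = jnp.array([1.0, 0.0, -1.0])
v2 = jnp.array([1.0, 1.0, 1.0])

out = jvp_split_tangent(x1, x2, v1, v2)  # [17., 2.]

# Closed forms for comparison
expected_primal = jnp.sum(x1 ** 2) + 2 * jnp.sum(x2)
expected_tangent = 2 * jnp.dot(x1, v1) + 2 * jnp.sum(v2)
```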

Hints

jax jvp pytree structured-tangents
