hard
primitives
jvp with Structured Tangents
Why this matters
jax.jvp computes the forward-mode derivative — the Jacobian-vector
product (JVP) — simultaneously with the primal output. For multi-argument
functions, both primals and tangents are tuples matching the function’s
argument list.
This is crucial for:
- Directional derivatives in physics simulations.
- Efficient Hessian-vector products (forward-over-reverse).
- Tangent propagation through structured inputs (e.g., pytrees of params from multiple model components).
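The forward-over-reverse pattern mentioned above can be sketched as follows. This is a minimal illustration, not a canonical implementation: the helper name hvp and the test function g are hypothetical and chosen only so the Hessian is easy to check by hand.

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Forward-over-reverse: push the tangent v through the gradient function.
    # jax.jvp returns (grad f(x), H(x) @ v); keep only the second element.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def g(x):
    # Hypothetical test function; its Hessian is diag(6 * x).
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 0.0])
hv = hvp(g, x, v)  # Hessian is diag([6, 12]), so H @ v = [6, 0]
```

The full Hessian is never materialized here; only its product with v is computed.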
Worked mini-example
import jax
import jax.numpy as jnp

def f(a, b):
    return jnp.sum(a * b)

primals = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
tangents = (jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0]))
primal_out, tangent_out = jax.jvp(f, primals, tangents)
# primal_out = 11.0 (1*3 + 2*4)
# tangent_out = (1*3 + 0*4) + (1*0 + 2*1) = 5.0
# (product rule: sum(tangent_a * b + a * tangent_b))
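One way to sanity-check the result is to compare it with reverse mode: contracting the gradients from jax.grad with the same tangents should give the same number. The sketch below repeats the definitions so it runs on its own; the names da, db, and check are illustrative.

```python
import jax
import jax.numpy as jnp

def f(a, b):
    return jnp.sum(a * b)

a, b = jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])
da, db = jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])

# Forward mode: one pass yields the directional derivative.
_, tangent_out = jax.jvp(f, (a, b), (da, db))

# Reverse mode: gradients with respect to both arguments,
# then contract each with its tangent.
grad_a, grad_b = jax.grad(f, argnums=(0, 1))(a, b)
check = jnp.vdot(grad_a, da) + jnp.vdot(grad_b, db)
```

Both tangent_out and check should evaluate to 5.0 for these inputs.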
Common pitfalls
- Tangent shapes must match primal shapes exactly: mismatched shapes raise a shape error at trace time.
- Cannot use None as a tangent in jvp: provide a zero array (jnp.zeros_like(x)) to zero out a direction instead.
- Tuple wrapping is mandatory: even for a single argument, primals and tangents must be tuples: (x,), not x.
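The last two pitfalls can be illustrated with a short sketch. The function h and the variable names are hypothetical, used only for this example.

```python
import jax
import jax.numpy as jnp

def f(a, b):
    return jnp.sum(a * b)

a, b = jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])
da = jnp.array([1.0, 0.0])

# Differentiate along `a` only: pass zeros (not None) as the tangent for b.
_, t_a_only = jax.jvp(f, (a, b), (da, jnp.zeros_like(b)))  # sum(da * b) = 3.0

# Single-argument function: primals and tangents are still wrapped in tuples.
def h(x):
    return jnp.sum(x ** 2)

_, t_single = jax.jvp(h, (a,), (da,))  # 2 * dot(a, da) = 2.0
```

Using jnp.zeros_like keeps the tangent's shape and dtype aligned with its primal, which also covers the first pitfall.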
Problem
Implement jvp_split_tangent(x_first, x_second, v_first, v_second) that:
- Defines f(x_first, x_second) = sum(x_first ** 2 + 2 * x_second).
- Computes primal_out, tangent_out = jax.jvp(f, (x_first, x_second), (v_first, v_second)).
- Returns jnp.array([primal_out, tangent_out]), a 1-D array of shape (2,).
Inputs:
- x_first, x_second: 1-D JAX arrays (same shape).
- v_first, v_second: tangent arrays (same shape as x_first, x_second).
Returns: 1-D array shape (2,) — [primal, tangent].
Analytic tangent: 2 * dot(x_first, v_first) + 2 * sum(v_second).
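A minimal sketch of one way to meet this specification, checked against the analytic tangent. It follows the three steps in the problem statement directly and is not the site's reference solution; the test inputs x1, x2, v1, v2 are made up for illustration.

```python
import jax
import jax.numpy as jnp

def jvp_split_tangent(x_first, x_second, v_first, v_second):
    # f as defined in the problem statement.
    def f(x_first, x_second):
        return jnp.sum(x_first ** 2 + 2 * x_second)

    # One tangent per argument, in the same order as the primals.
    primal_out, tangent_out = jax.jvp(
        f, (x_first, x_second), (v_first, v_second)
    )
    return jnp.array([primal_out, tangent_out])

# Illustrative inputs (assumed, not from the problem page).
x1, x2 = jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])
v1, v2 = jnp.array([1.0, 1.0]), jnp.array([0.5, 0.5])

out = jvp_split_tangent(x1, x2, v1, v2)
analytic = 2 * jnp.dot(x1, v1) + 2 * jnp.sum(v2)
```

For these inputs the primal is (1 + 6) + (4 + 8) = 19 and the tangent is 2 * 3 + 2 * 1 = 8, so out[1] should agree with analytic.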
jax
jvp
pytree
structured-tangents