easy primitives

jvp Basics

Why this matters

jax.jvp is JAX’s forward-mode autodiff transformation. Given a function f, a point x, and a tangent vector v, it computes both f(x) (the primal output) and Df(x) · v (the directional derivative of f at x along v) in a single forward pass, with no second traversal needed.

Forward-mode is most efficient when the input dimension is much smaller than the output dimension (e.g., a neural network with one scalar input but many outputs). Each call yields one Jacobian-vector product; when v is a standard basis vector, that product is exactly one column of the Jacobian. The reverse-mode counterpart (jax.grad / jax.vjp) is cheaper when there are many inputs and few outputs (the classic loss-function scenario).
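To make the "one column per call" idea concrete, here is a minimal sketch. The function g is a hypothetical example (it does not come from the exercise), and jax.jacfwd is used only as a reference to compare against.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical example function mapping R^2 -> R^3.
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])

x = jnp.array([1.0, 2.0])
e0 = jnp.array([1.0, 0.0])  # standard basis vector selecting input 0

# One jvp call along e0 gives the Jacobian column for input 0.
_, col0 = jax.jvp(g, (x,), (e0,))

# Full 3x2 Jacobian, built here only for comparison.
J = jax.jacfwd(g)(x)

columns_match = bool(jnp.allclose(col0, J[:, 0]))
print(columns_match)
```

Recovering the whole Jacobian this way takes one jvp call per input dimension, which is why forward mode suits functions with few inputs.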

Understanding jvp unlocks higher-order techniques: Hessian-vector products, forward-over-reverse AD, and jax.jvp-based sensitivity analysis.
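As an illustration of forward-over-reverse AD, a Hessian-vector product can be sketched by applying jax.jvp to jax.grad(f). The helper name hvp and the cubic test function below are assumptions chosen for the example, not part of the exercise.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar-valued function; its Hessian is diag(6 * x).
    return jnp.sum(x ** 3)

def hvp(fun, x, v):
    # Forward-over-reverse: push the tangent v through the gradient function.
    # The tangent output of this jvp is H(x) @ v, without forming H.
    return jax.jvp(jax.grad(fun), (x,), (v,))[1]

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 1.0])

hv = hvp(f, x, v)
expected = 6.0 * x * v  # diag(6x) @ v, worked out by hand for this f
print(bool(jnp.allclose(hv, expected)))
```

The full Hessian is never materialised, which is the usual reason for preferring this pattern when x is large.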

Worked mini-example

import jax, jax.numpy as jnp

def f(x):
    return x ** 2

primal, tangent = jax.jvp(f, (3.0,), (1.0,))
# primal  → 9.0   (f(3) = 3^2)
# tangent → 6.0   (f'(3) · 1 = 2·3·1)

For a vector-valued x:

def f(x):
    return jnp.sum(x ** 2)

primal, tangent = jax.jvp(f, (jnp.array([1.0, 2.0, 3.0]),), (jnp.array([1.0, 0.0, 0.0]),))
# primal  → 14.0          (1 + 4 + 9)
# tangent →  2.0          (2·1·1 + 2·2·0 + 2·3·0 = 2)
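For a scalar-valued function, the tangent returned by jax.jvp should agree with the dot product of the gradient and v. The sketch below checks this for the same f; the direction v is an arbitrary value chosen for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.5, -1.0, 2.0])  # arbitrary direction chosen for the example

# Forward mode: directional derivative in one pass.
_, tangent = jax.jvp(f, (x,), (v,))

# Reverse mode: full gradient, then a dot product with v.
via_grad = jnp.dot(jax.grad(f)(x), v)

print(bool(jnp.allclose(tangent, via_grad)))
```

Both routes give the same number; jvp simply avoids building the full gradient when only one direction is needed.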

Common pitfalls

  • jvp returns a tuple (primal, tangent), not a single value. Always unpack it: primal, tangent = jax.jvp(...).
  • Primals and tangents must be tuples (or lists) with one entry per positional argument of f: jax.jvp(f, x, v) with bare arrays raises a TypeError. Write jax.jvp(f, (x,), (v,)) even for a single argument.
  • Tangent shape must match input shape: the tangent v must have the same shape and dtype as x, and a mismatch raises an error. In practice, use floating-point values for both x and v.
  • Don’t confuse jvp with grad: jax.grad only works on scalar-valued functions; jax.jvp handles any output shape.
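The pitfalls above can be seen in a short sketch. The two-argument function h and its inputs are hypothetical. The exception types caught are an assumption about current JAX behaviour and may differ between versions.

```python
import jax
import jax.numpy as jnp

def h(a, b):
    # Hypothetical two-argument function with a vector-valued output.
    return a * jnp.sin(b)

a = jnp.array([1.0, 2.0])
b = jnp.array([0.0, 0.0])
da = jnp.array([1.0, 1.0])  # tangent for a
db = jnp.array([1.0, 1.0])  # tangent for b

# One tuple entry per positional argument, in the same order.
primal, tangent = jax.jvp(h, (a, b), (da, db))
# By hand: tangent = sin(b) * da + a * cos(b) * db

# Passing bare arrays instead of tuples is rejected.
try:
    jax.jvp(h, a, da)
    bare_array_rejected = False
except (TypeError, ValueError):
    bare_array_rejected = True

print(primal, tangent, bare_array_rejected)
```

Note that h returns a vector, so jax.grad(h) would not apply here, while jax.jvp works without changes.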

Problem

Implement jvp_quadratic(x, v) that computes the primal and tangent of f(x) = sum(x²) at point x along tangent direction v.

Return a 1-D array of shape (2,):

  • result[0] = primal value = sum(x²)
  • result[1] = tangent = 2 (x · v), i.e. twice the dot product of x and v

Use jax.jvp — do not compute the derivative by hand.

Hints

jax jvp forward-mode
