jvp Basics
Why this matters
jax.jvp is JAX’s forward-mode autodiff primitive. Given a function
f, a point x, and a tangent vector v, it computes both f(x) (the primal
output) and Df(x) · v (the directional derivative) in a single forward
pass, with no second traversal needed.
Forward-mode is most efficient when the input dimension is much smaller
than the output dimension (e.g., a neural network with one scalar input
but many outputs). With a one-hot tangent, each call yields one column of
the Jacobian, so a full Jacobian costs one call per input dimension. The
reverse-mode counterpart (jax.grad / jax.vjp) is cheaper when there
are many inputs and few outputs (the classic loss-function scenario).
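To make the "one column per call" idea concrete, here is a minimal sketch that assembles a Jacobian from one-hot tangents and compares it with jax.jacfwd. The function g and the helper jacobian_column are hypothetical names chosen for illustration; they are not part of JAX.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical map from R^2 to R^3, chosen only for illustration.
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])

def jacobian_column(fun, x, i):
    # A one-hot tangent e_i selects column i of the Jacobian.
    e_i = jnp.zeros_like(x).at[i].set(1.0)
    _, column = jax.jvp(fun, (x,), (e_i,))
    return column

x = jnp.array([1.0, 2.0])

# One jvp call per input dimension, stacked as columns.
columns = [jacobian_column(g, x, i) for i in range(x.shape[0])]
J_from_jvp = jnp.stack(columns, axis=1)  # shape (3, 2)

# Reference: JAX's built-in forward-mode Jacobian.
J_reference = jax.jacfwd(g)(x)           # shape (3, 2)
```

Here two jvp calls recover the full 3×2 Jacobian, whereas reverse mode would need one pass per output (three in this case).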
Understanding jvp unlocks higher-order techniques: Hessian-vector
products, forward-over-reverse AD, and jax.jvp-based sensitivity
analysis.
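As one illustration of the higher-order use mentioned above, a Hessian-vector product can be sketched by applying jax.jvp to jax.grad (forward-over-reverse). The names loss and hvp below are assumptions made for this example, not JAX functions.

```python
import jax
import jax.numpy as jnp

def loss(x):
    # Hypothetical scalar function; its Hessian is diag(6 * x).
    return jnp.sum(x ** 3)

def hvp(fun, x, v):
    # Forward-over-reverse: push the tangent v through grad(fun).
    _, hv = jax.jvp(jax.grad(fun), (x,), (v,))
    return hv

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, -1.0])

hv = hvp(loss, x, v)

# Reference: build the full Hessian and multiply (fine for tiny inputs only).
hv_reference = jax.hessian(loss)(x) @ v
```

The jvp-based version never materializes the full Hessian, which is the main reason to prefer it for large inputs.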
Worked mini-example
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2

primal, tangent = jax.jvp(f, (3.0,), (1.0,))
# primal → 9.0 (f(3) = 3^2)
# tangent → 6.0 (f'(3) · 1 = 2·3·1)
For a vector-valued x:
def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 0.0])
primal, tangent = jax.jvp(f, (x,), (v,))
# primal → 14.0 (1 + 4 + 9)
# tangent → 2.0 (2·1·1 + 2·2·0 + 2·3·0 = 2)
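For a scalar-valued function, the tangent returned by jax.jvp should equal the dot product of the gradient with v. The sketch below checks this numerically with an arbitrarily chosen v.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.5, -1.0, 2.0])  # arbitrary direction, for illustration

# Forward mode: directional derivative in one pass.
_, tangent = jax.jvp(f, (x,), (v,))

# Reverse mode: full gradient, then project onto v.
directional = jnp.dot(jax.grad(f)(x), v)
# Both equal 2·1·0.5 + 2·2·(-1) + 2·3·2 = 9
```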
Common pitfalls
- jvp returns a tuple (primal, tangent), not a single value. Always unpack it: primal, tangent = jax.jvp(...).
- Inputs must be tuples (or lists): jax.jvp(f, x, v) raises a TypeError; you must write jax.jvp(f, (x,), (v,)) even for a single argument.
- Tangent shape must match input shape: the tangent v must have the same shape and dtype as x. A mismatch raises a JAX error.
- Don’t confuse jvp with grad: jax.grad only works on scalar-valued functions; jax.jvp handles any output shape.
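The pitfalls above can be reproduced in a few lines. This is a sketch only: the exact exception types and messages may differ between JAX versions, so the shape-mismatch case catches both TypeError and ValueError. The variable names are chosen for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2  # array in, array out (non-scalar output)

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.ones_like(x)

# Pitfall: passing bare arrays instead of tuples.
try:
    jax.jvp(f, x, v)
    bare_args_error = None
except TypeError as err:
    bare_args_error = type(err).__name__

# Pitfall: tangent shape differs from the primal shape.
try:
    jax.jvp(f, (x,), (jnp.ones(2),))
    shape_error = None
except (TypeError, ValueError) as err:
    shape_error = type(err).__name__

# Correct call: jvp accepts the non-scalar output and returns a pair.
primal, tangent = jax.jvp(f, (x,), (v,))

# jax.grad rejects the same function because its output is not a scalar.
try:
    jax.grad(f)(x)
    grad_error = None
except TypeError as err:
    grad_error = type(err).__name__
```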
Problem
Implement jvp_quadratic(x, v) that computes the primal and
tangent of f(x) = sum(x²) at point x along tangent direction v.
Return a 1-D array of shape (2,):
- result[0] = primal value = sum(x²)
- result[1] = tangent = 2·x · v (dot product)
Use jax.jvp — do not compute the derivative by hand.
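One possible approach, offered as a sketch rather than the reference solution: wrap f in a helper, call jax.jvp with tuple arguments, and stack the two scalars into a shape-(2,) array. The use of jnp.stack to build the result is an assumption about the expected output; the problem statement only fixes the shape and the meaning of each entry.

```python
import jax
import jax.numpy as jnp

def jvp_quadratic(x, v):
    def f(x):
        return jnp.sum(x ** 2)

    # jax.jvp returns (f(x), Df(x) · v); both are scalars here.
    primal, tangent = jax.jvp(f, (x,), (v,))
    return jnp.stack([primal, tangent])  # shape (2,)

# Same inputs as the worked example: expect primal 14 and tangent 2.
x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 0.0])
result = jvp_quadratic(x, v)
```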