medium primitives

jvp for Sensitivity Analysis

Why this matters

When tuning a model, you often ask: if I perturb parameter w by a small amount dw, how much does the loss change? That question is exactly a directional derivative, and jax.jvp answers it in one forward pass without constructing the full gradient vector.
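One way to see that the tangent really answers "how much does the loss change" is to compare it with a finite difference. The sketch assumes a toy loss sum((w * x)^2) with x = [1, 2], which is the worked example with b = 0. The step size eps = 1e-3 is chosen only for illustration. A forward difference carries an O(eps) error, so it should match the tangent closely but not exactly.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])

def loss(w):
    # Toy loss: same shape as the worked example, with the bias fixed at 0.
    return jnp.sum((w * x) ** 2)

w, dw = 1.0, 1.0
primal, tangent = jax.jvp(loss, (w,), (dw,))

# First-order check: (loss(w + eps*dw) - loss(w)) / eps should approximate the tangent.
eps = 1e-3
finite_diff = (loss(w + eps * dw) - loss(w)) / eps
print(tangent, finite_diff)  # tangent is 10.0; finite_diff is close to it
```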

This forward-mode sensitivity is the backbone of:

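  • Trust-region and Newton-CG methods, which need Hessian-vector products (forward-over-reverse: jvp(grad(f), ...)), and Gauss-Newton steps, which combine a jvp with a vjp.
  • Sensitivity analysis in scientific computing and differentiable simulation.
  • Checking a reverse-mode gradient by hand: compare jvp along some direction dw with grad(f)(w) · dw.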
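The Hessian-vector product from the first bullet can be sketched in a few lines. The objective f here is a hypothetical example chosen so that the answer is easy to check by hand. Its gradient is [3w0^2 + w1, 3w1^2 + w0, 3w2^2], so H @ e0 = [6*w0, 1, 0]. The explicit Hessian is built only as a cross-check and is affordable only for tiny inputs.

```python
import jax
import jax.numpy as jnp

def f(w):
    # Hypothetical objective with one off-diagonal Hessian term (w0 * w1).
    return jnp.sum(w ** 3) + w[0] * w[1]

def hvp(f, w, v):
    # Forward-over-reverse: push direction v through grad(f) with jvp.
    # Returns H(w) @ v without ever forming the Hessian H.
    return jax.jvp(jax.grad(f), (w,), (v,))[1]

w = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 0.0, 0.0])

H = jax.hessian(f)(w)   # explicit Hessian, for comparison only
print(hvp(f, w, v))     # expected [3., 1., 0.]
print(H @ v)
```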

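For one directional derivative, jvp costs a single forward pass (a small constant factor over evaluating the loss) and stores nothing for a backward pass. Computing grad(loss)(w) * dw instead runs a forward and a backward pass and materialises the full gradient before taking the product. With a single scalar w the difference is negligible. When w is a large vector, jvp still needs only one pass per direction and never forms the gradient.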
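Both routes give the same number. The sketch below uses a hypothetical vector-weight variant of the loss, with one weight per data point, so that the gradient is a genuine vector. Here grad = 2 * w * x^2 = [2, 8, 18], and its dot product with dw is 1 - 8 + 36 = 29.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

def loss_vec(w):
    # Hypothetical variant: w is a vector with one weight per entry of x.
    return jnp.sum((w * x) ** 2)

w = jnp.array([1.0, 1.0, 1.0])
dw = jnp.array([0.5, -1.0, 2.0])

# Reverse mode: materialise the full gradient, then take the dot product.
via_grad = jnp.dot(jax.grad(loss_vec)(w), dw)

# Forward mode: one pass along dw; no gradient vector is formed.
_, via_jvp = jax.jvp(loss_vec, (w,), (dw,))

print(via_grad, via_jvp)  # both 29.0
```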

Worked mini-example

import jax, jax.numpy as jnp

x = jnp.array([1.0, 2.0])
b = 0.0

def loss(w):
    return jnp.sum((w * x + b) ** 2)

# loss(1.0) = 1^2 + 2^2 = 5.0
# d/dw loss = 2*sum((w*x + b)*x); at w=1, b=0 this is 2*sum(x^2) = 2*(1 + 4) = 10.0
# with dw=1, tangent = 10.0 * 1 = 10.0
primal, tangent = jax.jvp(loss, (1.0,), (1.0,))
# primal  -> 5.0
# tangent -> 10.0

The tangent (second element) is the directional derivative: the first-order rate at which the loss changes as w moves along dw. For small eps, loss(w + eps*dw) ≈ loss(w) + eps*tangent.
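Because the tangent is (d/dw loss) · dw, it is linear in the direction. Doubling dw doubles it, and flipping the sign of dw flips its sign. The sketch reuses the worked example's loss at w = 1.0. The three directions are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
b = 0.0

def loss(w):
    return jnp.sum((w * x + b) ** 2)

tangents = []
for dw in (1.0, 2.0, -0.5):
    # Same primal point, different perturbation directions.
    _, t = jax.jvp(loss, (1.0,), (dw,))
    tangents.append(float(t))

print(tangents)  # [10.0, 20.0, -5.0]
```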

Common pitfalls

  • jvp is forward mode: each call yields one directional derivative, which is cheap when the input is small (one scalar w here). Recovering a full gradient with jvp takes one call per input dimension, so with millions of parameters jax.grad (reverse mode) is far more efficient because it returns the whole gradient in one backward pass.
  • Tangents must match the primals in structure, shape, and dtype: if w is a Python float, dw must be a float scalar too (1.0, not 1). A mismatch raises an error.
  • Don’t confuse jvp with jax.grad: jax.grad(loss)(w) gives d/dw loss evaluated at w, and you still have to multiply by dw yourself. jvp folds that product into the forward pass: tangent = (d/dw loss) · dw.
  • Unused primal: in _, tangent = jax.jvp(...), the underscore discards the function value, since only the directional derivative is needed here.
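The first pitfall can be made concrete. To get a full gradient out of forward mode, you push one standard basis vector e_i through jvp per input dimension. Batching those calls with jax.vmap is roughly what jax.jacfwd does internally. The vector-weight loss is a hypothetical example. With only 3 inputs the cost is trivial, but the number of passes grows with dimension, while jax.grad needs one backward pass regardless.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

def loss_vec(w):
    # Hypothetical vector-weight loss, for illustration.
    return jnp.sum((w * x) ** 2)

w = jnp.ones(3)

def directional(v):
    # One forward-mode pass gives one directional derivative along v.
    return jax.jvp(loss_vec, (w,), (v,))[1]

# One jvp per basis vector e_i, batched with vmap: n passes for n inputs.
grad_fwd = jax.vmap(directional)(jnp.eye(3))

# Reverse mode: the whole gradient from a single backward pass.
grad_rev = jax.grad(loss_vec)(w)

print(grad_fwd, grad_rev)  # both [2, 8, 18]
```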

Problem

Implement sensitivity(w, b, x, dw) that returns the directional derivative of loss(w) = sum((w * x + b)^2) at w along direction dw.

Formally: sensitivity = d/dw [sum((w*x + b)^2)] · dw

Use jax.jvp β€” do not compute the derivative by hand.

  • w: Python float β€” the weight scalar.
  • b: Python float β€” the bias scalar.
  • x: 1-D jax array β€” the data vector.
  • dw: Python float β€” the perturbation direction.

Returns: scalar.
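One possible way to meet this spec is sketched here, not presented as the reference solution. The sketch closes over b and x so that w is the only differentiated input. It returns the tangent as a 0-d JAX array, which the spec's "scalar" is assumed to allow. As a hand check (used only for testing, not inside the function), at w = 1, b = 1, x = [1, 2], dw = 0.5 the value is 2*((2)(1) + (3)(2))*0.5 = 8.

```python
import jax
import jax.numpy as jnp

def sensitivity(w, b, x, dw):
    # Close over b and x so that w is the only differentiated input.
    def loss(w_):
        return jnp.sum((w_ * x + b) ** 2)

    # jvp returns (loss(w), directional derivative along dw); keep the tangent.
    _, tangent = jax.jvp(loss, (w,), (dw,))
    return tangent

x = jnp.array([1.0, 2.0])
print(sensitivity(1.0, 0.0, x, 1.0))  # 10.0, matching the worked example
```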

Hints

jax jvp sensitivity
