jvp for Sensitivity Analysis
Why this matters
When tuning a model, you often ask: if I perturb parameter w by a
small amount dw, how much does the loss change? That question is
exactly a directional derivative, and jax.jvp answers it in one
forward pass without constructing the full gradient vector.
This forward-mode sensitivity is the backbone of:
- Trust-region methods and Gauss-Newton steps (Hessian-vector products via jvp(grad(f), ...)).
- Sensitivity analysis in scientific computing and differentiable simulation.
- Forward-AD probing to verify your backward pass manually.
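The Hessian-vector product mentioned above can be sketched as forward-over-reverse differentiation. This is a minimal illustration, not part of the exercise: the test function f and the helper name hvp are assumptions chosen so that the Hessian, diag(6 * w), is easy to check by hand.

```python
import jax
import jax.numpy as jnp

def f(w):
    # Hypothetical test function; its Hessian is diag(6 * w).
    return jnp.sum(w ** 3)

def hvp(fn, w, v):
    # Forward-over-reverse: push the direction v through grad(fn).
    # The tangent output of jvp is H(w) @ v, without ever forming H.
    return jax.jvp(jax.grad(fn), (w,), (v,))[1]

w = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, -1.0])
result = hvp(f, w, v)  # diag(6, 12, 18) @ v
```

Only one gradient evaluation and one forward-mode pass are involved, which is why this pattern is preferred over building the full Hessian when only its action on a vector is needed.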
For a single directional derivative, jvp is cheaper than computing
grad(loss)(w) and multiplying by dw, because it never materialises the
full gradient or stores intermediates for a backward pass. For a single
scalar w the difference is tiny; it becomes more noticeable as the
input dimension grows.
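To make the comparison concrete, the sketch below computes the same directional derivative both ways for a vector input. The function loss_vec and the chosen w and dw are illustrative assumptions, not part of the exercise.

```python
import jax
import jax.numpy as jnp

def loss_vec(w):
    # Hypothetical vector-input loss, used only for this comparison.
    return jnp.sum(jnp.sin(w) ** 2)

w = jnp.linspace(0.0, 1.0, 5)
dw = jnp.ones_like(w)

# Forward mode: one pass, returns the directional derivative directly.
_, tangent = jax.jvp(loss_vec, (w,), (dw,))

# Reverse mode: build the whole gradient, then contract it with dw.
via_grad = jnp.vdot(jax.grad(loss_vec)(w), dw)
```

Both values agree up to floating-point rounding; the jvp route simply never holds the gradient vector in memory.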
Worked mini-example
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
b = 0.0

def loss(w):
    return jnp.sum((w * x + b) ** 2)

# loss(1.0) = 1^2 + 2^2 = 5.0
# d/dw loss = 2*sum(x*(w*x + b)) = 2*(1+4) = 10.0 at w=1, b=0, so with dw=1 the tangent is 10.0
primal, tangent = jax.jvp(loss, (1.0,), (1.0,))
# primal  -> 5.0
# tangent -> 10.0
We return only the tangent (the second element): the rate of change of
the loss as w moves along direction dw.
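Because the tangent is linear in dw, scaling or flipping the direction scales or flips the result. A small sketch reusing the mini-example's loss (the variable names t_unit, t_half and t_neg are only for illustration):

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
b = 0.0

def loss(w):
    return jnp.sum((w * x + b) ** 2)

# Same point w = 1.0, three different directions.
_, t_unit = jax.jvp(loss, (1.0,), (1.0,))   # 10.0 * 1.0
_, t_half = jax.jvp(loss, (1.0,), (0.5,))   # 10.0 * 0.5
_, t_neg = jax.jvp(loss, (1.0,), (-2.0,))   # 10.0 * -2.0
```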
Common pitfalls
- jvp is forward-mode: cheap when the input dimension is small (one scalar w here). If you have millions of parameters and need the full gradient, jax.grad (reverse mode) is more efficient.
- Tangent shape must match input shape: dw must be a Python float (scalar) if w is a Python float. A shape mismatch raises an error.
- Don't confuse with jax.grad: jax.grad(loss)(w) gives d/dw loss evaluated at w, but you still need to multiply by dw. jvp does this dot product for free: tangent = (d/dw loss) · dw.
- Unused primal: _, tangent = jax.jvp(...); the underscore discards the function value. We only need the directional derivative here.
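The shape pitfall can be reproduced in a few lines. This is a sketch under the assumption that the mismatch surfaces as a TypeError or ValueError; the exact exception type and message may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
b = 0.0

def loss(w):
    return jnp.sum((w * x + b) ** 2)

# Matching shapes: scalar primal, scalar tangent.
_, ok_tangent = jax.jvp(loss, (1.0,), (1.0,))

# Mismatched shapes: scalar primal, length-2 tangent.
raised = False
try:
    jax.jvp(loss, (1.0,), (jnp.ones(2),))
except (TypeError, ValueError) as err:
    raised = True
    print(f"jvp rejected the tangent: {type(err).__name__}")
```

Note that primals and tangents are both passed as tuples, one entry per positional argument of the function.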
Problem
Implement sensitivity(w, b, x, dw) that returns the directional
derivative of loss(w) = sum((w·x + b)²) at w along direction dw.
Formally: sensitivity = d/dw [sum((w*x + b)^2)] · dw
Use jax.jvp; do not compute the derivative by hand.
- w: Python float, the weight scalar.
- b: Python float, the bias scalar.
- x: 1-D jax array, the data vector.
- dw: Python float, the perturbation direction.
Returns: scalar.
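One way to sanity-check an implementation without differentiating by hand is a central finite difference. The helper fd_sensitivity below is a hypothetical reference, not the intended solution (it does not use jax.jvp), and the step size eps is an assumed value that trades truncation error against float32 rounding.

```python
import jax.numpy as jnp

def fd_sensitivity(w, b, x, dw, eps=1e-3):
    # Central-difference estimate of d/dt loss(w + t*dw) at t = 0.
    def loss(w_):
        return jnp.sum((w_ * x + b) ** 2)
    return (loss(w + eps * dw) - loss(w - eps * dw)) / (2.0 * eps)

x = jnp.array([1.0, 2.0])
# Same inputs as the worked mini-example, so the estimate should
# land close to the tangent computed there.
estimate = fd_sensitivity(1.0, 0.0, x, 1.0)
```

A jvp-based sensitivity(w, b, x, dw) should agree with this estimate to within a loose tolerance; exact equality is not expected in float32.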