Per-Example Gradients via vmap(grad(...))
Why this matters
Per-example gradients are essential for several modern ML techniques:
- Differentially private SGD (DP-SGD): clip each example's gradient norm independently before averaging, so no single example has outsized influence.
- Importance sampling: reweight gradients by example difficulty.
- Gradient outlier detection: identify training examples whose gradients are anomalously large.
- Influence functions: measure how a single training point affects model predictions.
The naive approach jax.grad(lambda w: sum_loss(w, x_batch, y_batch))
aggregates all examples first, losing per-example information. The correct
approach uses vmap(grad(...)): define the loss for a single example,
take its gradient, then vectorize over the batch dimension.
For a scalar parameter, vmap(grad(f)) returns a result of shape (N,),
one gradient per example, at roughly the same cost as one batch gradient,
because vmap vectorizes rather than loops.
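To make the contrast concrete, here is a minimal sketch comparing the two approaches on the same data. The names example_loss and sum_loss and the batch values are illustrative assumptions, not part of the original problem.

```python
import jax
import jax.numpy as jnp

def example_loss(w, x, y):
    # Squared error for a single example; returns a scalar.
    return (w * x - y) ** 2

def sum_loss(w, x_batch, y_batch):
    # Batched loss: the per-example losses are summed before differentiation.
    return jnp.sum((w * x_batch - y_batch) ** 2)

w = 1.0
x_batch = jnp.array([1.0, 2.0, 3.0])
y_batch = jnp.array([0.5, 0.0, 1.0])

# A single scalar: the per-example contributions are already summed away.
summed_grad = jax.grad(sum_loss)(w, x_batch, y_batch)

# One gradient per example, shape (3,).
per_example = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))(
    w, x_batch, y_batch
)

print(summed_grad.shape, per_example.shape)
```

Summing per_example recovers summed_grad, which is a quick way to confirm that the per-example version carries strictly more information.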
Worked mini-example
import jax
import jax.numpy as jnp
def per_example_loss(w, b, x, y):
    return (w * x - y) ** 2  # scalar for one example
grad_fn = jax.grad(per_example_loss, argnums=0)
w, b = 1.0, 0.0
x_batch = jnp.array([1.0, 2.0])
y_batch = jnp.array([0.0, 0.0])
# grad_fn(w, b, x, y) computes one gradient.
# vmap maps it over the batch.
grads = jax.vmap(grad_fn, in_axes=(None, None, 0, 0))(w, b, x_batch, y_batch)
# grads ≈ [2.0, 8.0] (per-example: 2*(w*x - y)*x)
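As one application, the per-example gradients above can feed a DP-SGD-style clipping step. The following is only a sketch: clipped_mean_grad and clip_norm are hypothetical names, the noise that DP-SGD adds after clipping is omitted, and because w is a scalar here the gradient norm reduces to an absolute value.

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, b, x, y):
    return (w * x - y) ** 2  # scalar for one example

def clipped_mean_grad(w, b, x_batch, y_batch, clip_norm):
    # Hypothetical helper: clip each example's gradient, then average.
    grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, None, 0, 0))(
        w, b, x_batch, y_batch
    )
    # For a scalar parameter the gradient norm is just the absolute value.
    norms = jnp.abs(grads)
    # Scale by min(1, clip_norm / norm); the small constant avoids dividing by zero.
    scale = jnp.minimum(1.0, clip_norm / (norms + 1e-12))
    return jnp.mean(grads * scale)

g = clipped_mean_grad(
    1.0, 0.0, jnp.array([1.0, 2.0]), jnp.array([0.0, 0.0]), clip_norm=4.0
)
print(g)
```

With the batch from the worked example, the raw gradients are 2.0 and 8.0; only the second exceeds clip_norm=4.0, so it is scaled down to 4.0 before averaging.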
Common pitfalls
- Forgetting in_axes: by default, vmap maps over axis 0 of EVERY argument. Scalars w and b have no batch axis, so you must pass in_axes=(None, None, 0, 0). None broadcasts the scalar; 0 vectorizes over axis 0.
- Taking grad of the batched loss: grad(sum_loss) collapses to one number and loses per-example resolution. Always define the loss for a single example first, then vmap.
- Assuming vmap is slow: vmap vectorizes the computation; it does not Python-loop. Per-example grads via vmap cost approximately one batch gradient, not N individual gradients.
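The in_axes pitfall is easy to reproduce. The sketch below assumes the scalars are passed as 0-d jax arrays; in the JAX versions I am aware of, the default mapping then fails with a ValueError because those arguments have no axis 0 to map over.

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, b, x, y):
    return (w * x - y) ** 2  # scalar for one example

grad_fn = jax.grad(per_example_loss, argnums=0)
w, b = jnp.asarray(1.0), jnp.asarray(0.0)  # 0-d arrays: no batch axis
x_batch = jnp.array([1.0, 2.0])
y_batch = jnp.array([0.0, 0.0])

# Default in_axes=0 tries to map over axis 0 of w and b, which do not have one.
try:
    jax.vmap(grad_fn)(w, b, x_batch, y_batch)
    default_failed = False
except ValueError as err:
    default_failed = True
    print("default in_axes failed:", err)

# Marking w and b as unmapped (None) broadcasts them across the batch.
grads = jax.vmap(grad_fn, in_axes=(None, None, 0, 0))(w, b, x_batch, y_batch)
print(grads.shape)
```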
Problem
Implement per_example_grads(w, b, x_batch, y_batch) that returns a
1-D array of shape (N,): one gradient of (w*x - y)^2 w.r.t. w per
example. Do not sum or average; return the raw per-example values.
- w: scalar, the weight.
- b: scalar, unused, kept for signature parity.
- x_batch, y_batch: 1-D jax arrays of shape (N,).
Returns: 1-D jax array of shape (N,).
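One way to sanity-check an attempt is to compare it against the closed-form derivative 2*(w*x - y)*x. The helper names analytic_grads, check_per_example_grads, and averaged_candidate below are hypothetical, and the test values are arbitrary; this is a sketch of a checker, not the site's grader.

```python
import jax
import jax.numpy as jnp

def analytic_grads(w, x_batch, y_batch):
    # Closed form of d/dw (w*x - y)^2, evaluated elementwise.
    return 2.0 * (w * x_batch - y_batch) * x_batch

def check_per_example_grads(candidate, w=1.5, b=0.0):
    # Hypothetical checker: the output must have shape (N,) and match the closed form.
    x_batch = jnp.array([1.0, 2.0, -3.0])
    y_batch = jnp.array([0.5, 0.0, 1.0])
    out = jnp.asarray(candidate(w, b, x_batch, y_batch))
    expected = analytic_grads(w, x_batch, y_batch)
    return out.shape == x_batch.shape and bool(jnp.allclose(out, expected))

def averaged_candidate(w, b, x_batch, y_batch):
    # A deliberately wrong attempt: averaging collapses the result to a scalar.
    return jnp.mean(analytic_grads(w, x_batch, y_batch))

print(check_per_example_grads(averaged_candidate))
```

Passing your own per_example_grads as candidate should make the check return True; the averaged attempt fails on shape alone.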