
Per-Example Gradients via vmap(grad(...))

Why this matters

Per-example gradients are essential for several modern ML techniques:

  • Differentially private SGD (DP-SGD): clip each example’s gradient to a maximum L2 norm before summing, then add Gaussian noise, so no single example has outsized influence.
  • Importance sampling: reweight gradients by example difficulty.
  • Gradient outlier detection: identify training examples whose gradients are anomalously large.
  • Influence functions: measure how a single training point affects model predictions.
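To make the DP-SGD bullet concrete, here is a minimal sketch of the clip-then-noise aggregation step. It assumes a linear model with a vector weight w. The helper name dp_mean_grad and its clip_norm and noise_multiplier parameters are hypothetical, and the sketch omits privacy accounting, so it is not a complete DP-SGD implementation.

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, x, y):
    # Squared error of a linear model on ONE example; w and x are length-D vectors.
    return (jnp.dot(w, x) - y) ** 2

def dp_mean_grad(w, x_batch, y_batch, clip_norm, noise_multiplier, key):
    # Hypothetical helper: the clip-then-noise aggregation step of DP-SGD.
    # 1) Per-example gradients, shape (N, D).
    grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))(w, x_batch, y_batch)
    # 2) Rescale each row so its L2 norm is at most clip_norm.
    norms = jnp.linalg.norm(grads, axis=1)
    scale = jnp.minimum(1.0, clip_norm / jnp.maximum(norms, 1e-12))
    clipped = grads * scale[:, None]
    # 3) Add Gaussian noise (std = noise_multiplier * clip_norm) to the sum, then average.
    noise = noise_multiplier * clip_norm * jax.random.normal(key, w.shape)
    return (clipped.sum(axis=0) + noise) / x_batch.shape[0]

w = jnp.zeros(2)
x_batch = jnp.array([[1.0, 0.0], [0.0, 10.0]])
y_batch = jnp.array([1.0, 1.0])
# Noise disabled here so the result is deterministic.
g = dp_mean_grad(w, x_batch, y_batch, clip_norm=1.0, noise_multiplier=0.0,
                 key=jax.random.PRNGKey(0))
# g → [-0.5, -0.5]: raw gradients [-2, 0] and [0, -20] are each clipped to norm 1.
```

Without clipping, the second example's gradient would dominate the average by a factor of ten; after clipping, both examples contribute equally.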

The naive approach, jax.grad(lambda w: sum_loss(w, x_batch, y_batch)), aggregates all examples before differentiating, so only the sum of the per-example gradients survives. The standard approach composes vmap(grad(...)): define the loss for a single example, take its gradient, then vectorize over the batch dimension.
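A small sketch of the contrast, assuming sum_loss is the summed squared error over the batch (the page does not define it, so this definition is illustrative). The batch gradient equals the sum of the per-example gradients, which is exactly the information lost by differentiating the summed loss.

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, x, y):
    # Loss for ONE example (x and y are scalars here).
    return (w * x - y) ** 2

def sum_loss(w, x_batch, y_batch):
    # Assumed definition: summed squared error over the whole batch.
    return jnp.sum((w * x_batch - y_batch) ** 2)

w = 1.0
x_batch = jnp.array([1.0, 2.0, 3.0])
y_batch = jnp.array([0.0, 0.0, 0.0])

# Naive: a single scalar, the sum of all per-example gradients.
batch_grad = jax.grad(sum_loss)(w, x_batch, y_batch)  # → 28.0

# Per-example: one gradient per example, 2*(w*x - y)*x for each x.
per_ex = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))(w, x_batch, y_batch)
# per_ex → [2.0, 8.0, 18.0], and per_ex.sum() equals batch_grad
```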

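For a scalar parameter, vmap(grad(f)) returns an array of shape (N,), one gradient per example; for a parameter of shape w.shape it returns shape (N, *w.shape). Because vmap vectorizes rather than loops, the compute cost is typically comparable to one batched gradient, although memory grows with N since every per-example gradient is materialized.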
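A minimal sketch of the shape rule with a vector parameter, assuming a linear model whose weight w has length D. Each row of the result is one example's gradient with respect to the whole weight vector.

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, x, y):
    # Linear model on ONE example: w and x are length-D vectors, y is a scalar.
    return (jnp.dot(w, x) - y) ** 2

N, D = 4, 3
w = jnp.ones(D)
x_batch = jnp.arange(N * D, dtype=jnp.float32).reshape(N, D)
y_batch = jnp.zeros(N)

# w is shared (None); x and y are mapped over their leading axis (0).
per_example_grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))(w, x_batch, y_batch)
# per_example_grads.shape → (4, 3): one length-3 gradient per example
```

The same pattern extends to pytrees of parameters (for example a dict of arrays): every leaf of the result gains a leading axis of size N.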

Worked mini-example

import jax
import jax.numpy as jnp

def per_example_loss(w, b, x, y):
    return (w * x - y) ** 2   # scalar for one example

grad_fn = jax.grad(per_example_loss, argnums=0)

w, b = 1.0, 0.0
x_batch = jnp.array([1.0, 2.0])
y_batch = jnp.array([0.0, 0.0])

# grad_fn(w, b, x, y) computes one gradient.
# vmap maps it over the batch.
grads = jax.vmap(grad_fn, in_axes=(None, None, 0, 0))(w, b, x_batch, y_batch)
# grads → [2.0, 8.0]   (per-example: 2*(w*x - y)*x)
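To check the worked example above, the sketch below compares the vmap result with an explicit Python loop (one grad call per example) and with the closed-form derivative. The loop is only a reference for correctness; it is the slow pattern that vmap avoids.

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, b, x, y):
    return (w * x - y) ** 2   # scalar for one example; b is unused

grad_fn = jax.grad(per_example_loss, argnums=0)

w, b = 1.0, 0.0
x_batch = jnp.array([1.0, 2.0])
y_batch = jnp.array([0.0, 0.0])

# Vectorized: one call covers the whole batch.
vmapped = jax.vmap(grad_fn, in_axes=(None, None, 0, 0))(w, b, x_batch, y_batch)

# Reference: explicit Python loop, one grad call per example.
looped = jnp.stack([grad_fn(w, b, x, y) for x, y in zip(x_batch, y_batch)])

# Closed form for this loss: d/dw (w*x - y)^2 = 2*(w*x - y)*x.
analytic = 2 * (w * x_batch - y_batch) * x_batch
# All three → [2.0, 8.0]
```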

Common pitfalls

  • Forgetting in_axes: by default, vmap maps over axis 0 of EVERY argument. The scalars w and b have no batch axis, so pass in_axes=(None, None, 0, 0). None marks an argument as unmapped (the same value is passed for every example); 0 maps over its leading axis.
  • Taking grad of the batched loss: grad(sum_loss) collapses to one number and loses per-example resolution. Always define the loss for a single example first, then vmap.
  • Assuming vmap is slow: vmap vectorizes the computation rather than running a Python loop. Per-example gradients via vmap typically cost on the order of one batched gradient, not N separate gradient calls.
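A short sketch of the first pitfall, reusing the loss from the worked example: with the default in_axes, vmap tries to map axis 0 of the scalar w and raises a ValueError before any computation runs; marking w and b as unmapped fixes it.

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, b, x, y):
    return (w * x - y) ** 2   # scalar for one example; b is unused

grad_fn = jax.grad(per_example_loss, argnums=0)
w, b = 1.0, 0.0
x_batch = jnp.array([1.0, 2.0])
y_batch = jnp.array([0.0, 0.0])

# Default in_axes=0 asks for axis 0 of the 0-d scalar w, which does not exist.
try:
    jax.vmap(grad_fn)(w, b, x_batch, y_batch)
    default_failed = False
except ValueError:
    default_failed = True

# Unmapped scalars (None) plus mapped batch arrays (0) works as intended.
grads = jax.vmap(grad_fn, in_axes=(None, None, 0, 0))(w, b, x_batch, y_batch)
# grads → [2.0, 8.0]
```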

Problem

Implement per_example_grads(w, b, x_batch, y_batch) that returns a 1-D array of shape (N,) containing, for each example, the gradient of (w*x - y)^2 with respect to w. Do not sum or average; return the raw per-example values.

  • w: scalar, the weight.
  • b: scalar, unused; kept for signature parity.
  • x_batch, y_batch: 1-D jax arrays of shape (N,).

Returns: 1-D jax array of shape (N,).
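One possible sketch of per_example_grads, following the vmap(grad(...)) pattern described on this page. The nested helper name loss is illustrative and not part of the specification.

```python
import jax
import jax.numpy as jnp

def per_example_grads(w, b, x_batch, y_batch):
    """Gradient of (w*x - y)^2 with respect to w, one value per example."""
    def loss(w, b, x, y):
        # Loss for a single example; b is accepted but ignored.
        return (w * x - y) ** 2

    grad_w = jax.grad(loss, argnums=0)
    # w and b are shared across the batch (None); x and y are mapped over axis 0.
    return jax.vmap(grad_w, in_axes=(None, None, 0, 0))(w, b, x_batch, y_batch)

out = per_example_grads(1.0, 0.0, jnp.array([1.0, 2.0, 3.0]), jnp.array([0.0, 1.0, 0.0]))
# out → [2.0, 4.0, 18.0]
```

Pass w as a float: jax.grad rejects integer-typed inputs.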
