medium primitives

vmap(grad) vs grad(sum(vmap))

Why this matters

When training a neural network, you often need the gradient of the total batch loss. There are two equivalent ways to compute it:

  • Way (a): compute per-example gradients with vmap(grad(f)), then .sum(). This is useful when you also need the individual gradients (e.g., for DP-SGD clipping).
  • Way (b): vectorize the forward pass with vmap(f), sum the losses, then differentiate once with grad(...). This is the standard training loop pattern.

By the linearity of differentiation, the gradient of a sum equals the sum of the gradients, so the two routes produce the same value (up to floating-point rounding). Understanding this equivalence helps you choose the right tool: per-example gradients for DP-SGD, a single grad call for normal training. Verifying the equality numerically is a useful sanity check when learning the JAX autodiff API.
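
The linearity argument can be written out explicitly. In this sketch, \ell stands for per_example_loss and x_1, ..., x_N for the entries of the batch; both symbols are notation introduced here for illustration.

```latex
\underbrace{\frac{\partial}{\partial w} \sum_{i=1}^{N} \ell(w, x_i)}_{\text{way (b)}}
\;=\;
\underbrace{\sum_{i=1}^{N} \frac{\partial \ell}{\partial w}(w, x_i)}_{\text{way (a)}}
```

For \ell(w, x) = (w x)^2, both sides equal 2 w \sum_i x_i^2.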

Worked mini-example

import jax
import jax.numpy as jnp

def per_example_loss(w, x):
    return (w * x) ** 2

w = 1.0
x_batch = jnp.array([1.0, 2.0])

# Way (a): per-example grads → sum
per_grads = jax.vmap(jax.grad(per_example_loss, argnums=0),
                     in_axes=(None, 0))(w, x_batch)
# per_grads → [2.0, 8.0]
way_a = per_grads.sum()   # 10.0

# Way (b): vectorized forward → sum → grad
way_b = jax.grad(
    lambda w: jax.vmap(per_example_loss, in_axes=(None, 0))(w, x_batch).sum()
)(w)                       # 10.0

assert jnp.isclose(way_a, way_b)   # True
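
To illustrate why way (a) is worth having, here is a minimal sketch of the clipping step that DP-SGD applies to per-example gradients before summing. It is not a full DP-SGD implementation (no noise is added), and clip_norm is a hypothetical threshold chosen only for this example. Because w is a scalar, the norm of each per-example gradient is just its absolute value.

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, x):
    return (w * x) ** 2

w = 1.0
x_batch = jnp.array([1.0, 2.0])
clip_norm = 4.0  # hypothetical clipping threshold, for illustration only

# Per-example gradients, as in way (a): [2.0, 8.0]
per_grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0))(w, x_batch)

# Scale down any gradient whose magnitude exceeds clip_norm.
# The small constant guards against division by zero.
scale = jnp.minimum(1.0, clip_norm / (jnp.abs(per_grads) + 1e-12))
clipped_sum = (per_grads * scale).sum()   # 2.0 + 4.0 = 6.0
```

The clipped sum differs from the unclipped total of 10.0, and it cannot be obtained from way (b), which only ever sees the gradient of the summed loss.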

Common pitfalls

  • vmap(grad(f), in_axes=(None, 0))(w, x_batch) returns shape (N,): a 1-D array of per-example gradients. Call .sum() to collapse it to a scalar before comparing with way (b).
  • grad(lambda w: vmap(f, in_axes=(None, 0))(w, x_batch).sum())(w) already returns a scalar (a 0-D array), so no further .sum() is needed on way (b).
  • in_axes must say which arguments are mapped: in_axes=(None, 0) holds the scalar w fixed across the batch and maps over axis 0 of x_batch. With the default in_axes=0, vmap would try to map over the scalar w and raise an error.
  • The equality holds for any differentiable loss, up to floating-point rounding: it is a consequence of linearity, not a coincidence of the specific function.
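
The shape and generality points above can be checked with a short sketch. The function other_loss is a hypothetical alternative loss, introduced here only to show that nothing depends on the squared form used in the worked example.

```python
import jax
import jax.numpy as jnp

def other_loss(w, x):
    # Hypothetical loss, used only to show the result is not specific to (w * x) ** 2.
    return jnp.sin(w * x) + w ** 2

w = 0.5
x_batch = jnp.array([1.0, 2.0, 3.0])

# Way (a): one gradient per example, then sum.
per_grads = jax.vmap(jax.grad(other_loss), in_axes=(None, 0))(w, x_batch)
way_a = per_grads.sum()

# Way (b): sum the vectorized losses, then differentiate once.
way_b = jax.grad(
    lambda w_: jax.vmap(other_loss, in_axes=(None, 0))(w_, x_batch).sum()
)(w)

print(per_grads.shape)  # (3,)
print(way_b.shape)      # ()
print(bool(jnp.allclose(way_a, way_b, rtol=1e-5)))
```

The comparison uses a tolerance rather than exact equality because the two routes may accumulate floating-point rounding differently.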

Problem

Implement equivalent_two_ways(w, x_batch) that:

  1. Computes way (a): vmap(grad(per_example_loss), in_axes=(None, 0))(w, x_batch).sum().
  2. Computes way (b): grad(lambda w: vmap(per_example_loss, in_axes=(None, 0))(w, x_batch).sum())(w).
  3. Returns jnp.array([way_a, way_b]); both entries should be equal.

Use per_example_loss(w, x) = (w * x) ** 2.

  • w: scalar.
  • x_batch: 1-D jax array of shape (N,).

Returns: 1-D jax array of shape (2,): [way_a, way_b].
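
One possible shape for such a function, assembled directly from the worked mini-example, is sketched below. It is an illustration of the specification rather than a reference solution, and it assumes w is passed as a float, since jax.grad rejects integer inputs.

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, x):
    return (w * x) ** 2

def equivalent_two_ways(w, x_batch):
    # Way (a): per-example gradients of shape (N,), summed to a scalar.
    way_a = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0))(w, x_batch).sum()
    # Way (b): gradient of the summed, vectorized loss (already a 0-D array).
    way_b = jax.grad(
        lambda w_: jax.vmap(per_example_loss, in_axes=(None, 0))(w_, x_batch).sum()
    )(w)
    return jnp.array([way_a, way_b])

result = equivalent_two_ways(1.0, jnp.array([1.0, 2.0]))
print(result)  # [10. 10.]
```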

Hints

jax vmap grad composition
