vmap(grad) vs grad(sum(vmap))
Why this matters
When training a neural network, you often need the gradient of the total batch loss. There are two equivalent ways to compute it:
- Way (a): compute per-example gradients with vmap(grad(f)), then .sum(). Useful when you also need the individual gradients (e.g., for DP-SGD clipping).
- Way (b): vectorize the forward pass with vmap(f), sum the losses, then differentiate once with grad(...). This is the standard training-loop pattern.
By the linearity of differentiation (the gradient of a sum is the sum of the gradients), these two routes produce the same value. Understanding this equivalence helps you choose the right tool: per-example gradients for DP-SGD, a single grad call for normal training. Verifying the equality numerically is a useful sanity check when learning the JAX autodiff API.
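The linearity argument can be written out explicitly. The following is a short derivation, assuming a scalar parameter w and the loss f(w, x) = (w x)^2 used in the worked example on this page; the symbols f, x_i and N are notation introduced here for illustration.

```latex
\begin{aligned}
\underbrace{\frac{\partial}{\partial w} \sum_{i=1}^{N} f(w, x_i)}_{\text{way (b)}}
  &= \underbrace{\sum_{i=1}^{N} \frac{\partial}{\partial w} f(w, x_i)}_{\text{way (a)}} \\
\text{with } f(w, x) = (w x)^2: \quad
\frac{\partial f}{\partial w} &= 2 w x^2 \\
w = 1,\; x = (1, 2): \quad
\sum_{i=1}^{2} 2 w x_i^2 &= 2 + 8 = 10
\end{aligned}
```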
Worked mini-example
import jax
import jax.numpy as jnp

def per_example_loss(w, x):
    return (w * x) ** 2

w = 1.0
x_batch = jnp.array([1.0, 2.0])

# Way (a): per-example grads -> sum
per_grads = jax.vmap(jax.grad(per_example_loss, argnums=0),
                     in_axes=(None, 0))(w, x_batch)
# per_grads ≈ [2.0, 8.0]
way_a = per_grads.sum()  # 10.0

# Way (b): vectorized forward -> sum -> grad
way_b = jax.grad(
    lambda w: jax.vmap(per_example_loss, in_axes=(None, 0))(w, x_batch).sum()
)(w)  # 10.0

assert jnp.isclose(way_a, way_b)  # True
Common pitfalls
- vmap(grad(f), in_axes=(None, 0))(w, x_batch) returns shape (N,): a 1-D array of per-example gradients. Call .sum() to collapse it to a scalar before comparing with way (b).
- grad(lambda w: vmap(f, in_axes=(None, 0))(w, x_batch).sum())(w) already returns a scalar. Do NOT call .sum() again on way (b); it is already a 0-D array.
- in_axes must match scalar vs. array args: in_axes=(None, 0) broadcasts w (scalar) and maps over axis 0 of x_batch.
- The equality holds for any differentiable loss: it is a consequence of linearity, not a coincidence of the specific function.
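To see that the shape rules and the equality carry over beyond the scalar case, here is a minimal sketch with a vector-valued parameter. The function vector_loss and the input values are hypothetical, chosen only for illustration; they are not part of the exercise. With w of shape (D,), the per-example gradients have shape (N, D), so way (a) would sum over the batch axis only.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss with a vector parameter (an assumption for illustration).
def vector_loss(w, x):
    return jnp.sin(jnp.dot(w, x)) ** 2

w = jnp.array([0.5, -1.0, 2.0])                      # shape (D,) = (3,)
x_batch = jnp.linspace(-1.0, 1.0, 12).reshape(4, 3)  # shape (N, D) = (4, 3)

# Way (a): per-example gradients have shape (N, D); sum over the batch axis.
per_grads = jax.vmap(jax.grad(vector_loss), in_axes=(None, 0))(w, x_batch)
way_a = per_grads.sum(axis=0)                        # shape (D,)

# Way (b): one gradient of the summed loss; no further reduction needed.
way_b = jax.grad(
    lambda w_: jax.vmap(vector_loss, in_axes=(None, 0))(w_, x_batch).sum()
)(w)

print(per_grads.shape, way_a.shape, way_b.shape)     # (4, 3) (3,) (3,)

# The two routes should agree up to floating-point rounding.
assert jnp.allclose(way_a, way_b, rtol=1e-4, atol=1e-5)
```

The comparison uses a tolerance rather than exact equality because the two routes may accumulate the batch sum in a different order.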
Problem
Implement equivalent_two_ways(w, x_batch) that:
- Computes way (a): vmap(grad(per_example_loss), in_axes=(None, 0))(w, x_batch).sum().
- Computes way (b): grad(lambda w: vmap(per_example_loss, in_axes=(None, 0))(w, x_batch).sum())(w).
- Returns jnp.array([way_a, way_b]); both entries should be equal.
Use per_example_loss(w, x) = (w * x) ** 2.
Inputs:
- w: scalar.
- x_batch: 1-D jax array of shape (N,).
Returns: 1-D jax array of shape (2,): [way_a, way_b].
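One possible shape for such a function, assembled from the worked mini-example, is sketched below. It is not the site's reference solution, only an illustration under the assumption that w is passed as a float (jax.grad rejects integer inputs).

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, x):
    return (w * x) ** 2

def equivalent_two_ways(w, x_batch):
    # Way (a): per-example gradients of shape (N,), then sum to a scalar.
    per_grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0))(w, x_batch)
    way_a = per_grads.sum()

    # Way (b): vectorized forward pass, summed loss, one grad call.
    def total_loss(w_):
        return jax.vmap(per_example_loss, in_axes=(None, 0))(w_, x_batch).sum()

    way_b = jax.grad(total_loss)(w)

    # Stack both scalars into a shape-(2,) array.
    return jnp.array([way_a, way_b])

result = equivalent_two_ways(1.0, jnp.array([1.0, 2.0]))
print(result)  # [10. 10.]
```

Defining total_loss as a named inner function rather than a lambda is only a readability choice; both forms close over x_batch in the same way.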