Use vmap to compute the per-sample gradient of a loss function across a batch.
Given a weight vector w of shape (d,) and a batch of inputs X of shape (batch, d)
with targets y of shape (batch,), compute the gradient of the squared loss
L_i = (w . X_i - y_i)^2 with respect to w for each sample individually.
Return the per-sample gradients as a matrix of shape (batch, d).
By the chain rule, the gradient for sample i is dL_i/dw = 2 * (w . X_i - y_i) * X_i.
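As a reference point, the closed-form gradient above can be evaluated for the whole batch with plain broadcasting. This is only a sketch for checking results, not the vmap solution the task asks for; the function name per_sample_grads_closed_form and the small example values are assumptions made for illustration.

```python
import jax.numpy as jnp

def per_sample_grads_closed_form(w, X, y):
    """Evaluate dL_i/dw = 2 * (w . X_i - y_i) * X_i for every row of X."""
    residual = X @ w - y                 # shape (batch,): w . X_i - y_i
    return 2.0 * residual[:, None] * X   # shape (batch, d)

# Hypothetical example data: d = 2, batch = 3.
w_ref = jnp.array([1.0, 2.0])
X_ref = jnp.array([[1.0, 0.0],
                   [0.0, 1.0],
                   [1.0, 1.0]])
y_ref = jnp.array([0.0, 1.0, 2.0])

grads_ref = per_sample_grads_closed_form(w_ref, X_ref, y_ref)
```

Every residual in this example equals 1, so each row of the result is simply 2 * X_i.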
Input:
w: A 1D tensor of shape (d,)
X: A 2D tensor of shape (batch, d)
y: A 1D tensor of shape (batch,)
Output: A 2D tensor of shape (batch, d) containing the per-sample gradients.
API Reference:
jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(w, X, y)
torch.vmap(grad_fn)(w_expanded, X, y)
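One possible way to put the JAX call from the API reference to work is sketched below. It assumes loss_fn is written for a single sample (x of shape (d,), scalar y), so that jax.grad differentiates with respect to its first argument w and jax.vmap maps over the batch axis of X and y while sharing w. The names per_sample_grad, w_demo, X_demo, and y_demo are illustrative, not part of the task.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    """Squared loss for ONE sample: x has shape (d,), y is a scalar."""
    return (jnp.dot(w, x) - y) ** 2

# jax.grad differentiates w.r.t. argument 0 (w) by default.
# in_axes=(None, 0, 0): share w across the batch, map over rows of X and entries of y.
per_sample_grad = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))

# Hypothetical example data: d = 2, batch = 3.
w_demo = jnp.array([1.0, 2.0])
X_demo = jnp.array([[1.0, 0.0],
                    [0.0, 1.0],
                    [1.0, 1.0]])
y_demo = jnp.array([0.0, 1.0, 2.0])

G_demo = per_sample_grad(w_demo, X_demo, y_demo)  # shape (batch, d)

# Cross-check against the closed form 2 * (w . X_i - y_i) * X_i.
expected = 2.0 * (X_demo @ w_demo - y_demo)[:, None] * X_demo
```

Passing None for w in in_axes avoids materializing a (batch, d) copy of the weights, which is what the expanded-weights variant would otherwise require.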