
Parallel Map with vmap

Use vmap to compute the per-sample gradient of a loss function across a batch.

Given a weight vector w of shape (d,) and a batch of inputs X of shape (batch, d) with targets y of shape (batch,), compute the gradient of the squared loss L_i = (w . X_i - y_i)^2 with respect to w for each sample individually.

Return the per-sample gradients as a matrix of shape (batch, d).

The gradient for sample i: dL_i/dw = 2 * (w . X_i - y_i) * X_i
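The analytical formula above can be evaluated for the whole batch at once with plain broadcasting, which makes a handy reference to check a vmap-based solution against (a minimal sketch with illustrative values):

```python
import jax.numpy as jnp

w = jnp.array([1.0, 2.0])                # shape (d,) = (2,)
X = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # shape (batch, d) = (2, 2)
y = jnp.array([0.5, 0.5])                # shape (batch,) = (2,)

# dL_i/dw = 2 * (w . X_i - y_i) * X_i, broadcast over the batch axis
residuals = X @ w - y                    # shape (batch,)
grads = 2.0 * residuals[:, None] * X     # shape (batch, d)
```

Here row i of `grads` is exactly 2 * (w . X_i - y_i) * X_i; the vmap solution below should reproduce it.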

Input:

  • w: A 1D tensor of shape (d,)
  • X: A 2D tensor of shape (batch, d)
  • y: A 1D tensor of shape (batch,)

Output: A 2D tensor of shape (batch, d) — per-sample gradients.

API Reference:

  • JAX: jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(w, X, y)
  • PyTorch: torch.vmap(torch.func.grad(loss_fn), in_dims=(None, 0, 0))(w, X, y)
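The JAX one-liner above expands into a short runnable sketch (the sample values are illustrative): `jax.grad` produces the gradient of the single-sample loss with respect to `w`, and `jax.vmap` maps that gradient function over the batch axis of `X` and `y` while sharing `w` across samples.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x_i, y_i):
    # Squared loss for a single sample: (w . x_i - y_i)^2
    return (jnp.dot(w, x_i) - y_i) ** 2

# grad differentiates w.r.t. argument 0 (w); vmap maps over the
# leading axis of X and y, with in_axes=None broadcasting w.
per_sample_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))

w = jnp.array([1.0, 2.0])
X = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([0.5, 0.5, 0.5])

grads = per_sample_grads(w, X, y)  # shape (batch, d) = (3, 2)

# Agrees with the analytical formula 2 * (w . X_i - y_i) * X_i
expected = 2.0 * (X @ w - y)[:, None] * X
```

Note that `in_axes=(None, 0, 0)` is what makes this per-sample: `w` is passed whole to every call, while `X` and `y` are sliced along axis 0.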

Hints

  • Keywords: vmap, per-sample gradient, jax.vmap, jax.grad, torch.vmap