medium framework

Scatter Add

Create a target tensor of zeros and accumulate (add) source values at specified indices.

Given a 1D source tensor src, a 1D index tensor indices, and a target size n, create a zero tensor of size n and add src[i] to target[indices[i]] for each i. If multiple source values map to the same target index, they should be summed.
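
The rule above can be sketched in plain Python, as a reference for the expected semantics rather than a framework solution. The function name scatter_add_reference is hypothetical, and raising on out-of-range indices is an assumption, since the problem statement does not say how they should be handled.

```python
def scatter_add_reference(src, indices, n):
    """Pure-Python reference: accumulate src[i] into target[indices[i]]."""
    if len(src) != len(indices):
        raise ValueError("src and indices must have the same length")
    target = [0.0] * n  # zero-initialised output of size n
    for value, idx in zip(src, indices):
        # Assumption: out-of-range indices are rejected (not specified in the problem).
        if not 0 <= idx < n:
            raise IndexError(f"index {idx} is out of range for size {n}")
        target[idx] += value  # repeated indices accumulate
    return target


result = scatter_add_reference([1.0, 2.0, 3.0, 4.0], [0, 2, 0, 3], 5)
print(result)  # [4.0, 0.0, 2.0, 4.0, 0.0]
```

Index 0 appears twice, so it receives 1.0 + 3.0 = 4.0; positions that no index points to stay at zero.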

Input:

  • src: A 1D tensor of shape (k,) — values to scatter
  • indices: A 1D integer tensor of shape (k,) — target positions
  • n: The size of the output tensor

Output: A 1D tensor of shape (n,) with accumulated values.

API Reference:

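  • PyTorch: target.scatter_add_(0, indices, src) (in-place) or torch.zeros(n).scatter_add(0, indices, src) (out-of-place); indices must have dtype int64, and the zero tensor must have the same dtype as src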
  • JAX: jnp.zeros(n).at[indices].add(src)
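
A minimal sketch of the PyTorch call listed above, assuming torch is installed. The wrapper name scatter_add_torch is hypothetical; casting the indices to int64 and matching the target's dtype to src are defensive choices, not requirements stated in the problem.

```python
import torch


def scatter_add_torch(src, indices, n):
    # Zero target with the same dtype and device as src, since scatter_add
    # requires the target and source dtypes to match.
    target = torch.zeros(n, dtype=src.dtype, device=src.device)
    # scatter_add expects an int64 index tensor; dim 0 is the only axis of a 1D tensor.
    return target.scatter_add(0, indices.to(torch.int64), src)


src = torch.tensor([1.0, 2.0, 3.0, 4.0])
indices = torch.tensor([0, 2, 0, 3])
out = scatter_add_torch(src, indices, 5)
print(out)  # tensor([4., 0., 2., 4., 0.])
```

The out-of-place scatter_add is used here so the function returns a fresh tensor; the in-place scatter_add_ behaves the same way but mutates the target it is called on.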

Hints

scatter, scatter_add, torch.scatter_add, jnp.at