medium
framework
Scatter Add
Create a target tensor of zeros and accumulate (add) source values at specified indices.
Given a 1D source tensor src, a 1D index tensor indices, and a target size n,
create a zero tensor of size n and add src[i] to target[indices[i]] for each i.
If multiple source values map to the same target index, they should be summed.
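To pin down these semantics before reaching for a framework, here is a minimal plain-Python sketch of the loop described above. The function name scatter_add_reference is hypothetical. The IndexError on out-of-range positions is an assumption, because the problem does not say how such indices should be handled.

```python
from typing import List, Sequence


def scatter_add_reference(src: Sequence[float], indices: Sequence[int], n: int) -> List[float]:
    """Plain-Python reference for scatter-add: target[indices[i]] += src[i]."""
    if len(src) != len(indices):
        raise ValueError("src and indices must have the same length k")
    target = [0.0] * n  # start from a zero target of size n
    for value, idx in zip(src, indices):
        # Assumption: reject positions outside [0, n) instead of wrapping or dropping them.
        if not 0 <= idx < n:
            raise IndexError(f"index {idx} is out of range for size {n}")
        target[idx] += value  # duplicate indices accumulate instead of overwriting
    return target


print(scatter_add_reference([1.0, 2.0, 3.0, 4.0], [0, 2, 0, 1], 4))
# [4.0, 4.0, 2.0, 0.0]
```

Index 0 appears twice, so target[0] receives src[0] + src[2] = 1.0 + 3.0 = 4.0, while position 3 is never targeted and stays 0.0.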
Input:
- src: A 1D tensor of shape (k,), the values to scatter
- indices: A 1D integer tensor of shape (k,), the target positions
- n: The size of the output tensor
Output: A 1D tensor of shape (n,) with accumulated values.
API Reference:
- PyTorch: target.scatter_add_(0, indices, src) or torch.zeros(n).scatter_add(0, indices, src)
- JAX: jnp.zeros(n).at[indices].add(src)
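A hedged sketch of the PyTorch route listed above, assuming PyTorch is installed. The wrapper scatter_add_torch is a hypothetical name, not part of PyTorch. It makes explicit two details the one-liner hides: scatter_add expects an int64 index tensor, and the target's dtype should match src, so the sketch casts the index and builds the zero tensor with src.dtype.

```python
import torch


def scatter_add_torch(src: torch.Tensor, indices: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical helper: zero target of size n with src accumulated at indices."""
    # scatter_add expects an int64 (LongTensor) index; cast in case int32 was passed.
    index = indices.to(torch.int64)
    # Match the target dtype and device to src so scatter_add does not reject mixed dtypes.
    target = torch.zeros(n, dtype=src.dtype, device=src.device)
    # Out-of-place form: returns a new tensor and leaves target untouched.
    return target.scatter_add(0, index, src)


src = torch.tensor([1.0, 2.0, 3.0, 4.0])
indices = torch.tensor([0, 2, 0, 1], dtype=torch.int32)
out = scatter_add_torch(src, indices, 4)
print(out)  # tensor([4., 4., 2., 0.])

# In-place variant from the API reference, writing into a pre-allocated target.
target = torch.zeros(4)
target.scatter_add_(0, indices.long(), src)
```

The JAX expression jnp.zeros(n).at[indices].add(src) follows the same pattern but is always out-of-place, because JAX arrays are immutable.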
Hints:
- scatter
- scatter_add
- torch.scatter_add
- jnp.ndarray.at (the .at property of JAX arrays)