Functional Update with .at[].add()
Why this matters
Scatter-add is one of JAX's most-used patterns: many element-wise contributions accumulate into a single output. Examples include histogram counting, the embedding-table backward pass (gradients for repeated token ids aggregate at the same row index), gather-then-scatter aggregations, and sparse matrix construction.
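The first two examples can each be written in a line or two. The sketch below uses toy inputs. `histogram_counts` and `embedding_sum` are hypothetical helper names, not JAX APIs. The embedding part relies on the fact that `jax.grad` of an indexing (gather) operation produces a scatter-add of the incoming gradient, so rows that are looked up more than once receive summed gradients.

```python
import jax
import jax.numpy as jnp

def histogram_counts(bin_ids, num_bins):
    # Hypothetical helper: each occurrence of a bin id adds 1 to that bin.
    # Repeated ids accumulate, which is exactly the scatter-add semantics.
    return jnp.zeros(num_bins, dtype=jnp.int32).at[bin_ids].add(1)

def embedding_sum(table, token_ids):
    # Hypothetical toy "loss": gather rows of the table and sum them.
    return table[token_ids].sum()

bin_ids = jnp.array([0, 2, 2, 1, 2])
counts = histogram_counts(bin_ids, 4)  # bins 0..3 -> [1, 1, 3, 0]

table = jnp.ones((4, 2))
token_ids = jnp.array([1, 3, 1])  # row 1 is looked up twice
# The backward pass of the gather scatter-adds the gradient into the table,
# so row 1 receives 1 + 1 = 2 per element and row 3 receives 1.
grad_table = jax.grad(embedding_sum)(table, token_ids)
```

Both results come from the same primitive: `counts` is built explicitly with `.at[].add()`, while `grad_table` is produced implicitly by differentiation.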
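The .at[idx].add(val) API is the canonical primitive for this. With .set(), if several indices refer to the same position, which write survives is implementation-defined. With .add(), the result is well-defined for repeated indices: XLA lowers it to a scatter-add that sums every contribution. On GPU the scatter may use atomic additions, so the floating-point summation order (and therefore the last bits of the result) can vary between runs.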
Worked mini-example
import jax.numpy as jnp
arr = jnp.zeros(5)
indices = jnp.array([0, 2, 0, 3])
values = jnp.array([1.0, 2.0, 3.0, 4.0])
result = arr.at[indices].add(values)
# result → [4.0, 0.0, 2.0, 4.0, 0.0]
#           ^ 1.0 + 3.0 (index 0 appears twice, so both contributions accumulate)
# Compare with .set(): index 0 would keep only one of the two writes (1.0 or 3.0),
# and which one wins is implementation-defined.
Common pitfalls
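- Confusing .set() and .add(): .set() overwrites; .add() accumulates. Pick the one whose semantics you need.
- Forgetting to assign the return value: as with .set(), .add(...) returns a new array. It does not mutate the original in place.
- Non-integer indices: cast them first, e.g. indices.astype(jnp.int32).
- Out-of-bounds indices: by default, JAX silently skips updates at out-of-bounds indices instead of raising an error. Pass the mode argument to change this, e.g. mode="clip" to clamp indices into range, or mode="promise_in_bounds" to promise that every index is valid (results are undefined if one is not).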
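The pitfalls above can be reproduced in a few lines. This is an illustrative sketch with made-up variable names. The out-of-bounds lines assume JAX's default behavior for scatter updates, which is to skip out-of-bounds updates unless another `mode` is given.

```python
import jax.numpy as jnp

arr = jnp.zeros(3)

# Pitfall: .add() returns a new array; discarding it leaves arr unchanged.
arr.at[0].add(1.0)          # result thrown away, arr is still [0., 0., 0.]
arr = arr.at[0].add(1.0)    # rebind the name to keep the update -> [1., 0., 0.]

# Pitfall: float indices must be cast to an integer dtype first.
float_idx = jnp.array([1.0, 1.0])
int_idx = float_idx.astype(jnp.int32)
arr = arr.at[int_idx].add(jnp.array([2.0, 3.0]))  # index 1 gets 2.0 + 3.0

# Pitfall: out-of-bounds indices (5 is past the end of a length-3 array).
oob_idx = jnp.array([0, 5])
vals = jnp.array([10.0, 20.0])
dropped = arr.at[oob_idx].add(vals)               # update at index 5 is skipped
clipped = arr.at[oob_idx].add(vals, mode="clip")  # index 5 is clamped to 2
```

The `dropped` and `clipped` results differ only in where the value 20.0 ends up. Because the default silently drops the update, an index bug can go unnoticed unless the inputs are checked.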
Problem
Implement accumulate_at_indices(arr, indices, values) that adds each value
in values into arr at the corresponding index in indices and returns the
new array. Repeated indices accumulate: all contributions are summed.
Two illustrative examples (not from the test set):
- arr = jnp.array([0.0, 0.0, 0.0]), indices = [0.0, 0.0, 1.0], values = [1.0, 2.0, 5.0] → [3.0, 5.0, 0.0] (1.0 + 2.0 accumulated at index 0; 5.0 at index 1)
- arr = jnp.array([10.0, 20.0, 30.0]), indices = [2.0], values = [-5.0] → [10.0, 20.0, 25.0]
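One possible sketch that satisfies both examples follows. It is not the site's reference solution, which is not shown here. It assumes the float indices hold whole-number values, because `astype(jnp.int32)` truncates toward zero. Wrapping it in `jax.jit` is optional.

```python
import jax
import jax.numpy as jnp

@jax.jit
def accumulate_at_indices(arr, indices, values):
    # Indices may arrive as floats (e.g. [0.0, 0.0, 1.0]); scatter needs integers.
    # Assumes whole-number values, since the cast truncates toward zero.
    idx = jnp.asarray(indices).astype(jnp.int32)
    # Scatter-add: every value is summed into its slot, including repeated indices.
    return arr.at[idx].add(jnp.asarray(values, dtype=arr.dtype))

out1 = accumulate_at_indices(jnp.array([0.0, 0.0, 0.0]),
                             jnp.array([0.0, 0.0, 1.0]),
                             jnp.array([1.0, 2.0, 5.0]))   # [3.0, 5.0, 0.0]
out2 = accumulate_at_indices(jnp.array([10.0, 20.0, 30.0]),
                             jnp.array([2.0]),
                             jnp.array([-5.0]))            # [10.0, 20.0, 25.0]
```

The function stays purely functional: `arr` itself is never modified, and the caller receives the accumulated copy.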