
Functional Update with .at[].add()

Why this matters

Scatter-add is one of JAX's most-used patterns: many contributions accumulated into positions of a single output array. Examples include histogram counting, the backward pass of an embedding lookup (gradients for the same row index are summed), gather-then-scatter aggregations, and sparse matrix construction.
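
Two of the examples above fit in a few lines. The sketch below (array contents and sizes are made up for illustration) builds a histogram with .at[].add(1), cross-checks it against jnp.bincount, and shows that the gradient of an embedding lookup is itself a scatter-add: a row gathered twice receives twice the gradient.

```python
import jax
import jax.numpy as jnp

# Histogram counting: every occurrence of a value adds 1 to its bin.
data = jnp.array([0, 2, 0, 3, 3, 3])
counts = jnp.zeros(5, dtype=jnp.int32).at[data].add(1)
# counts → [2, 0, 1, 3, 0]
same_as_bincount = bool((counts == jnp.bincount(data, length=5)).all())

# Embedding backward pass: the gradient of a gather is a scatter-add.
table = jnp.zeros((4, 2))    # 4 rows, embedding dim 2 (illustrative sizes)
ids = jnp.array([1, 3, 1])   # row 1 is looked up twice

def loss(t):
    return t[ids].sum()

grad_table = jax.grad(loss)(table)
# grad_table → [[0, 0], [2, 2], [0, 0], [1, 1]]: row 1's two contributions are summed

# The same gradient written directly as a scatter-add of ones.
manual = jnp.zeros_like(table).at[ids].add(jnp.ones((ids.shape[0], 2)))
```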

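The .at[idx].add(val) API is the canonical primitive for this. Unlike .set(), it gives a well-defined result for repeated indices: every contribution is summed, whereas with .set() it is implementation-defined which of several writes to the same position survives. The order in which the additions are applied is itself implementation-defined (on some hardware, such as GPUs, they run concurrently), so floating-point results can differ in the last bits between runs; integer sums do not depend on the order.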
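
One way to see what "well-defined" means here: the scatter-add result equals the original array plus the per-index totals computed by jax.ops.segment_sum. The values below are arbitrary and were picked so that every partial sum is exactly representable in float32, which keeps the comparison independent of the order in which the additions happen.

```python
import jax
import jax.numpy as jnp

arr = jnp.array([10.0, 20.0, 30.0, 40.0])
indices = jnp.array([1, 3, 1, 1])          # index 1 repeats three times
values = jnp.array([0.5, 2.0, 1.5, 1.0])

scattered = arr.at[indices].add(values)

# Reference: total the contributions per index first, then add them to arr.
per_index = jax.ops.segment_sum(values, indices, num_segments=arr.shape[0])
reference = arr + per_index
# both → [10.0, 23.0, 30.0, 42.0]
```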

Worked mini-example

import jax.numpy as jnp

arr     = jnp.zeros(5)
indices = jnp.array([0, 2, 0, 3])
values  = jnp.array([1.0, 2.0, 3.0, 4.0])

out = arr.at[indices].add(values)
# out → [4.0, 0.0, 2.0, 4.0, 0.0]
#        ^ 1.0 + 3.0 (index 0 appears twice, so both contributions accumulate)

# Compare with .set(): index 0 would receive only one of 1.0 or 3.0 (which one
# is implementation-defined in JAX), never their sum.

Common pitfalls

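  • Confusing .set() and .add(): .set() overwrites (with repeated indices, only one write survives); .add() sums every contribution.
  • Forgetting to assign the return value: as with .set(), .add(...) returns a new array. JAX arrays are immutable, so nothing is updated in place.
  • Non-integer indices: JAX rejects floating-point index arrays with a TypeError, so cast first with indices.astype(jnp.int32).
  • Out-of-bounds indices: by default, update operations such as .add() silently skip out-of-bounds updates rather than wrapping or raising (mode="drop" makes this explicit). Pass mode="clip" to clamp indices into range, or mode="promise_in_bounds" to promise XLA that every index is valid (no particular out-of-bounds behavior is then guaranteed).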
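
The pitfalls above can be reproduced in a short session. This sketch uses a length-3 array and an illustrative out-of-bounds index 5; the comments give the expected results under the default and mode="clip" behaviors described in the list.

```python
import jax.numpy as jnp

arr = jnp.zeros(3)

# Returned, not mutated: `arr` is still all zeros afterwards.
updated = arr.at[jnp.array([0, 0])].add(jnp.array([1.0, 2.0]))   # → [3.0, 0.0, 0.0]

# Float indices: cast before indexing (passing them as-is raises TypeError).
float_idx = jnp.array([1.0, 2.0])
cast = arr.at[float_idx.astype(jnp.int32)].add(1.0)               # → [0.0, 1.0, 1.0]

# Out-of-bounds index 5 on a length-3 array.
oob_idx = jnp.array([0, 5])
dropped = arr.at[oob_idx].add(1.0)                # default: update at 5 skipped → [1.0, 0.0, 0.0]
clipped = arr.at[oob_idx].add(1.0, mode="clip")   # 5 clamped to 2 → [1.0, 0.0, 1.0]
```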

Problem

Implement accumulate_at_indices(arr, indices, values) that adds each value in values into arr at the corresponding index in indices, returning the new array. Repeated indices accumulate: all contributions are summed.

Two illustrative examples (not from the test set):

  • arr = jnp.array([0.0, 0.0, 0.0]), indices = [0.0, 0.0, 1.0], values = [1.0, 2.0, 5.0] โ†’ [3.0, 5.0, 0.0] (1.0 + 2.0 accumulated at index 0; 5.0 at index 1)
  • arr = jnp.array([10.0, 20.0, 30.0]), indices = [2.0], values = [-5.0] โ†’ [10.0, 20.0, 25.0]
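
A minimal sketch of one possible solution, assuming (as in the examples) that indices may arrive as floating-point values holding whole numbers. Note that astype(jnp.int32) truncates toward zero, so a non-integral index would be silently truncated rather than rejected.

```python
import jax
import jax.numpy as jnp

def accumulate_at_indices(arr, indices, values):
    """Return a copy of `arr` with each value added at its index; repeats are summed."""
    arr = jnp.asarray(arr)
    # Float indices such as [0.0, 0.0, 1.0] must be cast before indexing.
    idx = jnp.asarray(indices).astype(jnp.int32)
    # Convert values (lists are accepted too) to an array of arr's dtype.
    vals = jnp.asarray(values, dtype=arr.dtype)
    return arr.at[idx].add(vals)

out1 = accumulate_at_indices(jnp.array([0.0, 0.0, 0.0]), [0.0, 0.0, 1.0], [1.0, 2.0, 5.0])
# out1 → [3.0, 5.0, 0.0]

out2 = jax.jit(accumulate_at_indices)(
    jnp.array([10.0, 20.0, 30.0]), jnp.array([2.0]), jnp.array([-5.0]))
# out2 → [10.0, 20.0, 25.0]
```

Nothing in the function branches on concrete index values, so it can be wrapped in jax.jit unchanged when called with array arguments, as the second call shows.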

Hints

jax functional-update at-add
