2-D Scatter via .at[].add()
Why this matters
2-D scatter-add is fundamental to many ML operations: building histograms
over 2-D bins, materializing sparse Jacobians, accumulating gradients into
a parameter matrix from sparse contributions, and computing per-cell
statistics. The .at[rows, cols].add(vals) pattern extends the 1-D scatter
to N dimensions with no extra machinery. JAX's functional update API accepts
one index array per axis, as long as the index arrays broadcast to a common
shape (typically equal-length 1-D arrays) and vals matches or broadcasts to
that shape.
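The histogram use case can be done with a single scatter-add. The sketch below bins 2-D points into a uniform grid. The function name histogram_2d is hypothetical. Two further details are assumptions made for this example: both axes share one [lo, hi) range, and points on or outside the edges are clipped into the border bins. For general use, JAX also provides jnp.histogram2d.

```python
import jax.numpy as jnp

def histogram_2d(x, y, n_bins_x, n_bins_y, lo=0.0, hi=1.0):
    """Hypothetical helper: count points (x[k], y[k]) into a uniform grid over [lo, hi)."""
    width = hi - lo
    # Map each coordinate to an integer bin index; clipping keeps points at
    # (or beyond) the range edges inside the border bins.
    ix = jnp.clip(((x - lo) / width * n_bins_x).astype(jnp.int32), 0, n_bins_x - 1)
    iy = jnp.clip(((y - lo) / width * n_bins_y).astype(jnp.int32), 0, n_bins_y - 1)
    counts = jnp.zeros((n_bins_x, n_bins_y), dtype=jnp.float32)
    # Each point adds 1.0 to its (ix, iy) cell; points sharing a cell accumulate.
    return counts.at[ix, iy].add(1.0)

h = histogram_2d(jnp.array([0.1, 0.1, 0.6, 0.9]),
                 jnp.array([0.2, 0.3, 0.7, 0.1]), 2, 2)
```

Here the first two points both fall into cell [0, 0], so h is [[2, 0], [1, 1]].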
Worked mini-example
```python
import jax.numpy as jnp

out = jnp.zeros((3, 3))
row = jnp.array([0, 1, 0])
col = jnp.array([1, 1, 1])
val = jnp.array([1.0, 2.0, 3.0])
out = out.at[row, col].add(val)
# out[0, 1] = 1.0 + 3.0 = 4.0 (entries 0 and 2 of the index arrays both target [0, 1])
# out[1, 1] = 2.0
# → [[0, 4, 0], [0, 2, 0], [0, 0, 0]]
```
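You can check the accumulation semantics by comparing the same scatter against NumPy's np.add.at, which does unbuffered in-place addition. The comparison also shows a contrast with NumPy. Its ordinary fancy-index += is buffered, so duplicate indices do not accumulate there. The index and value arrays are the ones from the mini-example.

```python
import numpy as np
import jax.numpy as jnp

row = np.array([0, 1, 0])
col = np.array([1, 1, 1])
val = np.array([1.0, 2.0, 3.0])

# JAX functional scatter-add: duplicate (row, col) pairs accumulate.
jax_out = jnp.zeros((3, 3)).at[row, col].add(val)

# NumPy reference: np.add.at is unbuffered, so duplicates also accumulate.
np_out = np.zeros((3, 3))
np.add.at(np_out, (row, col), val)

# Buffered fancy-index +=: out[0, 1] receives only one of its two
# contributions, not their sum.
buffered = np.zeros((3, 3))
buffered[row, col] += val
```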
Common pitfalls
- Mismatched lengths: row_indices, col_indices, and values must all be 1-D arrays of the same length.
- Non-integer indices: .at[] rejects floating-point index arrays, so cast both row and column indices with .astype(jnp.int32) before using them.
- Out-of-bounds indices: by default JAX performs no bounds checking (the documented default is mode="promise_in_bounds"), so you are promising that every index is valid and should not rely on what happens otherwise. Use mode="drop" to skip out-of-bounds writes explicitly, or mode="clip" to clamp indices into the valid range. Example: out.at[row, col].add(val, mode="drop").
- Building out_shape from a JAX array: jnp.zeros expects a static shape (a Python int or a tuple of ints), so passing a JAX array directly is fragile and fails under jax.jit, where the array is traced. Convert explicitly: (int(out_shape[0]), int(out_shape[1])).
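The sketch below exercises the integer-cast, out-of-bounds, and static-shape pitfalls on a small 2-by-2 target. The index values are made up for illustration, including the deliberately out-of-bounds row 5.

```python
import jax.numpy as jnp

# A shape stored as a JAX array: convert to Python ints before jnp.zeros.
shape_arr = jnp.array([2, 2])
out = jnp.zeros((int(shape_arr[0]), int(shape_arr[1])))

# Float index arrays are rejected by .at[]; cast them to int32 first.
row = jnp.array([0.0, 1.0, 5.0]).astype(jnp.int32)  # row 5 is out of bounds
col = jnp.array([1.0, 0.0, 1.0]).astype(jnp.int32)
val = jnp.array([1.0, 2.0, 3.0])

# mode="drop": the write aimed at row 5 is skipped.
dropped = out.at[row, col].add(val, mode="drop")

# mode="clip": row 5 is clamped to the last valid row (1), so 3.0 lands at [1, 1].
clipped = out.at[row, col].add(val, mode="clip")
```

With mode="drop" the result is [[0, 1], [2, 0]]. With mode="clip" the stray 3.0 is clamped into [1, 1], giving [[0, 1], [2, 3]].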
Problem
Implement scatter_2d_indices(out_shape, row_indices, col_indices, values)
that creates a zero-filled 2-D array of shape (H, W) and scatter-adds
values[k] at position [row_indices[k], col_indices[k]] for each k,
returning the result. Repeated (row, col) pairs accumulate.
Two illustrative examples (not from the test set):
- out_shape = [2, 3], row = [0, 1], col = [2, 0], val = [7.0, 3.0] → [[0, 0, 7], [3, 0, 0]]
- out_shape = [2, 2], row = [0, 0, 1], col = [0, 0, 1], val = [1.0, 4.0, 9.0] → [[5, 0], [0, 9]] (1.0 + 4.0 accumulated at [0, 0])
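Here is one possible sketch of scatter_2d_indices, checked against the two illustrative examples. It is not the site's reference solution. It assumes a float32 result and in-bounds indices. Because int(...) needs concrete values, wrapping the function in jax.jit would require out_shape to be a static argument, such as a hashable tuple.

```python
import jax.numpy as jnp

def scatter_2d_indices(out_shape, row_indices, col_indices, values):
    # jnp.zeros needs a static Python shape, so convert each entry to int
    # (works for a list, a tuple, or a concrete array).
    h, w = int(out_shape[0]), int(out_shape[1])
    rows = jnp.asarray(row_indices).astype(jnp.int32)
    cols = jnp.asarray(col_indices).astype(jnp.int32)
    vals = jnp.asarray(values, dtype=jnp.float32)  # assumed float32 output
    # Scatter-add: repeated (row, col) pairs accumulate instead of overwriting.
    return jnp.zeros((h, w), dtype=jnp.float32).at[rows, cols].add(vals)

ex1 = scatter_2d_indices([2, 3], [0, 1], [2, 0], [7.0, 3.0])
ex2 = scatter_2d_indices([2, 2], [0, 0, 1], [0, 0, 1], [1.0, 4.0, 9.0])
```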