medium primitives

2-D Scatter via .at[].add()

Why this matters

2-D scatter-add is fundamental to many ML operations: building histograms over 2-D bins, materializing sparse Jacobians, accumulating gradients into a parameter matrix from sparse contributions, and computing per-cell statistics. The .at[rows, cols].add(vals) pattern extends the 1-D scatter to N dimensions with no extra machinery: JAX's functional update API accepts one integer index array per axis, as long as the index arrays broadcast to a common shape and the values broadcast to that shape.
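As an illustration of the first use case, here is a minimal sketch of a 2-D histogram built with .at[].add(). The function name histogram_2d and its n_bins, lo, and hi parameters are hypothetical; it assumes square bins over [lo, hi) on both axes and masks out points outside that range with a weight of 0 instead of relying on an out-of-bounds mode.

```python
import jax.numpy as jnp


def histogram_2d(x, y, n_bins, lo=0.0, hi=1.0):
    """Count points (x[k], y[k]) into an n_bins x n_bins grid over [lo, hi)^2.

    Hypothetical helper: points outside [lo, hi) on either axis add 0.
    """
    width = (hi - lo) / n_bins
    # 1 for points inside the grid, 0 for points outside it.
    valid = (x >= lo) & (x < hi) & (y >= lo) & (y < hi)
    # Map each coordinate to an integer bin and clamp, so every write is in bounds.
    ix = jnp.clip(jnp.floor((x - lo) / width).astype(jnp.int32), 0, n_bins - 1)
    iy = jnp.clip(jnp.floor((y - lo) / width).astype(jnp.int32), 0, n_bins - 1)
    counts = jnp.zeros((n_bins, n_bins), dtype=jnp.int32)
    # Points sharing a cell accumulate; masked-out points contribute 0.
    return counts.at[ix, iy].add(valid.astype(jnp.int32))


x = jnp.array([0.1, 0.2, 0.6, 0.9, 1.5])
y = jnp.array([0.2, 0.1, 0.7, 0.1, 0.5])
print(histogram_2d(x, y, 2))  # [[2 0]
                              #  [1 1]]
```

Clamping the indices and weighting by the mask keeps every write in bounds, which also sidesteps negative bin indices, since JAX (like NumPy) would otherwise wrap -1 around to the last bin.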

Worked mini-example

import jax.numpy as jnp
out = jnp.zeros((3, 3))
row = jnp.array([0, 1, 0])
col = jnp.array([1, 1, 1])
val = jnp.array([1.0, 2.0, 3.0])
out = out.at[row, col].add(val)
# out[0,1] = 1.0 + 3.0 = 4.0  (entries 0 and 2 of the index arrays both target [0,1])
# out[1,1] = 2.0
# → [[0, 4, 0], [0, 2, 0], [0, 0, 0]]
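One way to check that repeated pairs really accumulate is to compare the scatter against a dense reference. This sketch (the helper name scatter_add_via_one_hot is hypothetical) rebuilds the same result as a sum of outer products of one-hot vectors, using jax.nn.one_hot and jnp.einsum.

```python
import jax
import jax.numpy as jnp


def scatter_add_via_one_hot(shape, row, col, val):
    """Dense reference for zeros(shape).at[row, col].add(val)."""
    h, w = shape
    r = jax.nn.one_hot(row, h, dtype=val.dtype)  # (K, H): one-hot of each row index
    c = jax.nn.one_hot(col, w, dtype=val.dtype)  # (K, W): one-hot of each col index
    # out[h, w] = sum_k val[k] * r[k, h] * c[k, w]
    return jnp.einsum("kh,kw->hw", r * val[:, None], c)


row = jnp.array([0, 1, 0])
col = jnp.array([1, 1, 1])
val = jnp.array([1.0, 2.0, 3.0])
scattered = jnp.zeros((3, 3)).at[row, col].add(val)
dense = scatter_add_via_one_hot((3, 3), row, col, val)
print(jnp.allclose(scattered, dense))  # True
```

The dense version costs O(K·H·W) work, so it is only practical as a test oracle on small shapes; the scatter itself touches just the K target cells.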

Common pitfalls

  • Mismatched lengths: row_indices, col_indices, and values must all be 1-D arrays of the same length.
  • Non-integer indices: JAX rejects floating-point indexers, so cast both row and col indices with .astype(jnp.int32) before using them in .at[].
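  • Out-of-bounds indices: JAX does not raise an error for out-of-bounds index arrays, and the default behavior is not something to rely on. Pass mode="drop" to skip OOB writes explicitly, or mode="clip" to clamp the indices to the valid range. Example: out.at[row, col].add(val, mode="drop"). Negative indices are not out of bounds: as in NumPy, -1 addresses the last row or column.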
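  • Building out_shape from a JAX array: jnp.zeros expects a static shape (an int or a tuple of ints), so passing a JAX array directly is fragile and fails under jax.jit, where its values are traced. Convert to Python ints explicitly: (int(out_shape[0]), int(out_shape[1])). This also requires out_shape to be concrete, not traced.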
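The pitfalls above can be exercised in a few lines. This sketch uses a deliberately out-of-bounds row index (5 on a 3×3 grid) to contrast mode="drop" with mode="clip"; the float indices and the array-valued out_shape are made-up example inputs.

```python
import jax.numpy as jnp

out = jnp.zeros((3, 3))
# Float indices must be cast to integers before they can index.
row = jnp.array([0.0, 5.0]).astype(jnp.int32)  # row 5 is out of bounds for H=3
col = jnp.array([0, 0], dtype=jnp.int32)
val = jnp.array([1.0, 2.0])

dropped = out.at[row, col].add(val, mode="drop")  # the write at row 5 is skipped
clipped = out.at[row, col].add(val, mode="clip")  # row 5 is clamped to row 2

# Example array-valued shape: convert to Python ints before calling jnp.zeros.
out_shape = jnp.array([2, 3])
grid = jnp.zeros((int(out_shape[0]), int(out_shape[1])))
print(dropped[:, 0], clipped[:, 0], grid.shape)
```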

Problem

Implement scatter_2d_indices(out_shape, row_indices, col_indices, values) that creates a zero-filled 2-D array of shape (H, W) and scatter-adds values[k] at position [row_indices[k], col_indices[k]] for each k, returning the result. Repeated (row, col) pairs accumulate.

Two illustrative examples (not from the test set):

  • out_shape = [2, 3], row = [0, 1], col = [2, 0], val = [7.0, 3.0] โ†’ [[0, 0, 7], [3, 0, 0]]
  • out_shape = [2, 2], row = [0, 0, 1], col = [0, 0, 1], val = [1.0, 4.0, 9.0] โ†’ [[5, 0], [0, 9]] (1.0 + 4.0 accumulated at [0, 0])
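A minimal sketch of one possible implementation follows; it is not the official solution. It assumes out_shape holds two concrete integers (a list, a tuple, or a non-traced array), so it is meant to be called outside jax.jit or with out_shape marked static.

```python
import jax.numpy as jnp


def scatter_2d_indices(out_shape, row_indices, col_indices, values):
    # Convert the shape to Python ints so jnp.zeros gets a static shape.
    h, w = int(out_shape[0]), int(out_shape[1])
    # Indexers must be integers; cast in case they arrive as floats.
    rows = jnp.asarray(row_indices).astype(jnp.int32)
    cols = jnp.asarray(col_indices).astype(jnp.int32)
    vals = jnp.asarray(values)
    out = jnp.zeros((h, w), dtype=vals.dtype)
    # .add() accumulates repeated (row, col) pairs instead of overwriting them.
    return out.at[rows, cols].add(vals)


print(scatter_2d_indices([2, 2], [0, 0, 1], [0, 0, 1], [1.0, 4.0, 9.0]))
```

Using .set() instead of .add() here would not accumulate duplicates, which is why the second illustrative example needs .add().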

Hints

jax scatter at-add
