
Functional Update with .at[].set()

Why this matters

JAX arrays are immutable. The NumPy/PyTorch idiom arr[idx] = val doesn’t work: JAX raises a TypeError explaining that JAX arrays are immutable and do not support in-place item assignment. The functional-update API is arr.at[idx].set(val), which returns a new array and leaves the original arr unchanged.

This pattern enables JAX’s pure-function discipline: every “mutation” just constructs a new value. Under jax.jit, when the pre-update array is not used afterward, XLA can typically perform the .at[].set() update in place on the existing buffer, so you get the ergonomics of functional code with the performance of mutation.
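A minimal sketch of the same discipline inside a jitted function. The helper zero_first is hypothetical, written only for illustration. Whether XLA reuses a buffer in place is a compiler decision that this code cannot observe; what it does show is that the caller’s array is never modified.

```python
import jax
import jax.numpy as jnp


@jax.jit
def zero_first(x):
    # Hypothetical helper: returns a copy of x with element 0 set to 0.0.
    # Under jit, XLA decides whether the scatter can reuse a buffer;
    # either way, the caller's x is not modified.
    return x.at[0].set(0.0)


x = jnp.array([1.0, 2.0, 3.0])
y = zero_first(x)
print(x)  # original values, unchanged
print(y)  # updated copy
```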

Worked mini-example

```python
import jax.numpy as jnp

arr = jnp.array([1.0, 2.0, 3.0])

# ❌ Doesn't work: JAX arrays are immutable
try:
    arr[1] = 99.0      # raises TypeError
except TypeError as err:
    print(err)

# ✅ Functional update: returns a new array
arr2 = arr.at[1].set(99.0)
# arr  is still [1.0, 2.0, 3.0]
# arr2 is        [1.0, 99.0, 3.0]
```

Compare with PyTorch: arr[1] = 99.0 mutates in place. With JAX you must capture the return value as a new binding.

Common pitfalls

  • Forgetting to assign the result: arr.at[1].set(99.0) returns a new array. If you don’t capture it (arr = arr.at[1].set(99.0)), the update is lost.
  • Trying arr[1] = 99.0: raises TypeError. JAX arrays reject item assignment by design; use arr.at[1].set(99.0) instead.
  • Repeated indices with .set(): when the same position appears more than once among the indices, which value ends up there is unspecified. You may observe “last write wins” in a given run, but XLA’s scatter (used by both eager and jitted code) makes no ordering guarantee, so don’t rely on a particular winner. Use .add() if you want duplicate contributions combined; it sums all of them.
  • Float indices: .at[] indexing requires integer indices, and JAX rejects floating-point indexers with an error. If your indices arrive as floats, cast them first, e.g. jnp.asarray(idx).astype(jnp.int32).
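The pitfalls above can be checked directly in a short sketch; the array values are chosen only for illustration.

```python
import jax.numpy as jnp

arr = jnp.array([1.0, 2.0, 3.0])

# Forgetting to assign: the result of .at[].set() must be captured.
arr.at[1].set(99.0)          # result discarded; arr is unchanged
arr = arr.at[1].set(99.0)    # rebinding keeps the update

# Repeated indices: .add() sums every contribution to index 1,
# whereas .set() gives no guarantee about which write wins.
counts = jnp.zeros(3).at[jnp.array([1, 1, 2])].add(1.0)

# Float indices: cast to an integer dtype before indexing.
float_idx = jnp.array([0.0, 2.0])
int_idx = float_idx.astype(jnp.int32)
cleared = arr.at[int_idx].set(0.0)
```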

Problem

Implement update_at_indices(arr, indices, values) that updates arr at each position listed in indices with the corresponding entry in values, returning the new array. arr must remain unchanged.

Two illustrative examples (not from the test set):

  • arr = jnp.array([5.0, 5.0, 5.0]), indices = [2.0], values = [0.0] → [5.0, 5.0, 0.0]
  • arr = jnp.array([1.0, 2.0, 3.0, 4.0]), indices = [0.0, 3.0], values = [100.0, -100.0] → [100.0, 2.0, 3.0, -100.0]
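One possible sketch of update_at_indices, assuming (as in the examples) that indices may arrive as floats and that they contain no duplicates, since .set() does not define which write wins for repeated positions.

```python
import jax.numpy as jnp


def update_at_indices(arr, indices, values):
    """Return a copy of arr with arr[indices[i]] replaced by values[i].

    Sketch only: indices are cast to int32 because the examples pass
    floats; duplicate indices are assumed not to occur.
    """
    idx = jnp.asarray(indices).astype(jnp.int32)
    vals = jnp.asarray(values, dtype=arr.dtype)
    return arr.at[idx].set(vals)  # functional update; arr is untouched


a = jnp.array([1.0, 2.0, 3.0, 4.0])
out = update_at_indices(a, [0.0, 3.0], [100.0, -100.0])
```

Casting inside the function keeps callers free to pass plain Python lists or float arrays, at the cost of silently truncating any non-integral index.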

Hints

jax functional-update at-set
