Functional Update with .at[].set()
Why this matters
JAX arrays are immutable. arr[idx] = val (NumPy/PyTorch syntax) doesn’t
work: JAX raises a TypeError explaining that JAX arrays are immutable and do not support in-place item assignment.
The functional-update API is
arr.at[idx].set(val), which returns a new array. The original arr is
unchanged.
This pattern enables JAX’s pure-function discipline: every “mutation” is just
constructing a new value. Under jax.jit, when the original array is not used
afterwards, XLA can often perform the .at[].set() update in place instead of
copying the buffer, so you get the ergonomics of functional code with the
performance of mutation.
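A minimal sketch of .at[].set() inside a jitted function. The helper zero_first is hypothetical and exists only for illustration. Whether XLA actually reuses the input buffer is a compiler decision this snippet cannot observe; what it does show is that the functional semantics are identical under jax.jit, so the caller’s array is never modified.

```python
import jax
import jax.numpy as jnp

@jax.jit
def zero_first(x):
    # Hypothetical helper: .at[].set() is traced and compiled like any other op.
    # XLA may lower this scatter to an in-place write if x's buffer is not
    # needed afterwards; that is an optimization, not a semantic change.
    return x.at[0].set(0.0)

x = jnp.array([1.0, 2.0, 3.0])
y = zero_first(x)
# x is unchanged: functional semantics hold under jit as well.
```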
Worked mini-example
import jax.numpy as jnp
# ❌ Doesn't work: JAX arrays are immutable
arr = jnp.array([1.0, 2.0, 3.0])
try:
    arr[1] = 99.0
except TypeError as err:
    print("TypeError:", err)
# ✅ Functional update: returns a new array
arr2 = arr.at[1].set(99.0)
# arr is still [1.0, 2.0, 3.0]
# arr2 is [1.0, 99.0, 3.0]
Compare with PyTorch: arr[1] = 99.0 mutates in place. With JAX you must
capture the return value as a new binding.
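To make the contrast with PyTorch concrete, here is a small sketch (the names a and b are illustrative) showing that rebinding a name to the updated array leaves every other reference to the original array untouched:

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = a                    # a second name bound to the same immutable array
a = a.at[1].set(99.0)    # rebind 'a' to a brand-new array
# b still holds [1.0, 2.0, 3.0]. In PyTorch, an in-place write through 'a'
# would also be visible through 'b', because both would share one tensor.
```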
Common pitfalls
- Forgetting to assign the result: arr.at[1].set(99.0) returns a new array. If you don’t capture it (arr = arr.at[1].set(99.0)), the update is lost.
- Trying arr[1] = 99.0: this raises TypeError, because JAX arrays do not support item assignment. Use .at[] instead.
- Repeated indices with .set(): when the same position appears more than once, which value ends up there is implementation-defined. It may differ across backends, and on some hardware even between runs, with or without jax.jit. Don’t rely on a specific tiebreaker. Use .add() if you want every update to a position accumulated.
- Float indices: arr.at[idx].set(val) requires integer indices. Cast with idx.astype(jnp.int32) if your data delivers floats.
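The pitfalls above can be checked with a short sketch. The index and value arrays are made up for illustration. The duplicate-index case uses .add() rather than .set() because .add() applies every update, so its result is well defined.

```python
import jax.numpy as jnp

arr = jnp.zeros(3)

# Forgetting to assign the result: the new array is discarded,
# so arr is still all zeros afterwards.
arr.at[0].set(1.0)

# Repeated indices: .add() applies every update, so the two writes
# to position 0 are summed (1.0 + 2.0) instead of competing.
idx = jnp.array([0, 0, 2])
acc = arr.at[idx].add(jnp.array([1.0, 2.0, 5.0]))

# Float indices: cast to an integer dtype before indexing.
float_idx = jnp.array([2.0])
fixed = arr.at[float_idx.astype(jnp.int32)].set(7.0)
```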
Problem
Implement update_at_indices(arr, indices, values) that updates arr at each
position listed in indices with the corresponding entry in values, returning
the new array. arr must remain unchanged.
Two illustrative examples (not from the test set):
- arr = jnp.array([5.0, 5.0, 5.0]), indices = [2.0], values = [0.0] → [5.0, 5.0, 0.0]
- arr = jnp.array([1.0, 2.0, 3.0, 4.0]), indices = [0.0, 3.0], values = [100.0, -100.0] → [100.0, 2.0, 3.0, -100.0]
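One possible approach, sketched under the assumption that indices and values are 1-D sequences of equal length, with indices possibly given as floats as in the examples. This is not the platform’s reference solution: it casts the indices to int32 and performs a single vectorised scatter with .at[].set().

```python
import jax.numpy as jnp

def update_at_indices(arr, indices, values):
    # Indices may arrive as floats (as in the examples), so cast them
    # to an integer dtype before using them as scatter positions.
    idx = jnp.asarray(indices).astype(jnp.int32)
    # Match the value dtype to arr so the scatter needs no implicit cast.
    vals = jnp.asarray(values, dtype=arr.dtype)
    # A single functional scatter: returns a new array, arr is untouched.
    # If indices contain duplicates, which value wins is implementation-defined.
    return arr.at[idx].set(vals)
```

Because the update is one .at[] call with no Python-level loop, the function should also trace cleanly under jax.jit.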