Dynamic Update Slice
Why this matters
lax.dynamic_update_slice(operand, update, start_indices) is the write-side
counterpart of lax.dynamic_slice: it writes an update array into an operand
starting at traced (runtime-variable) indices. Because JAX arrays are
immutable, this returns a new array rather than mutating in place.
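This is the JIT-compatible replacement for NumPy-style x[start:end] = update
(written x.at[start:end].set(update) in JAX), which requires concrete slice
bounds. Use it whenever you need to write a contiguous block of values into a
larger buffer, at a position known only at run time, inside a jit-compiled
function.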
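A minimal sketch of the jit use case, assuming a hypothetical helper write_block and illustrative arrays buf and blk. The offset arrives as a traced scalar, yet the call compiles because only the index value is dynamic; the shape of the written block stays static.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def write_block(buffer, block, offset):
    # `offset` is a traced int32 scalar inside jit. dynamic_update_slice
    # accepts it because the block's shape is fixed at trace time.
    return lax.dynamic_update_slice(buffer, block, (offset,))


buf = jnp.zeros(8, dtype=jnp.float32)
blk = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)

# Same shapes and dtypes, so both calls reuse one compiled function
# even though the offsets differ at run time.
out0 = write_block(buf, blk, jnp.int32(0))  # block at positions 0..2
out4 = write_block(buf, blk, jnp.int32(4))  # block at positions 4..6
```

buf itself stays all zeros; each call returns a fresh array.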
Worked mini-example
from jax import lax
import jax.numpy as jnp
x = jnp.zeros(6, dtype=jnp.float32)
patch = jnp.array([7.0, 8.0])
result = lax.dynamic_update_slice(x, patch, (jnp.int32(2),))
# → [0.0, 0.0, 7.0, 8.0, 0.0, 0.0]
The output has the same shape as operand. The update array is
written starting at the position given by start_indices, which holds one
index per dimension of operand (hence the one-element tuple for a 1-D array).
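The same call extends to more dimensions. Below is a small 2-D sketch with illustrative arrays grid and tile: the start tuple supplies one index per axis (row 1, column 2), update has the same number of dimensions as operand, and the result keeps the operand's (4, 4) shape.

```python
import jax.numpy as jnp
from jax import lax

grid = jnp.zeros((4, 4), dtype=jnp.int32)
tile = jnp.ones((2, 2), dtype=jnp.int32)

# One start index per axis of `grid`: the tile's top-left corner
# lands at row 1, column 2.
out = lax.dynamic_update_slice(grid, tile, (jnp.int32(1), jnp.int32(2)))
# out:
# [[0, 0, 0, 0],
#  [0, 0, 1, 1],
#  [0, 0, 1, 1],
#  [0, 0, 0, 0]]
```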
Common pitfalls
- Shape of output. The result has the shape of operand, not update.
- update must fit. An update longer than operand is rejected, and
  start_idx + len(update) should not exceed len(operand). An out-of-range
  start index does not raise an error: it is clamped so the update fits,
  so the write silently lands earlier than requested.
- Float start index. Cast to jnp.int32; passing a float tracer raises a type error.
- Immutability. The call returns a new array, and the original operand is
  unchanged. Forgetting to capture the return value is a common mistake.
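A short sketch of three of these pitfalls, reusing the zeros/patch setup from the worked example. A start of 5 cannot fit a length-2 patch into a length-6 array, so XLA clamps the start to 6 - 2 = 4. The original x is left untouched, and a float index is cast to int32 before use.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros(6, dtype=jnp.float32)
patch = jnp.array([7.0, 8.0], dtype=jnp.float32)

# Out of range: writing at 5 would need positions 5 and 6. No error is
# raised; the start is clamped to len(x) - len(patch) = 4 instead.
clamped = lax.dynamic_update_slice(x, patch, (jnp.int32(5),))

# Float index: cast to an integer dtype before passing it in.
start = jnp.float32(2.0)
ok = lax.dynamic_update_slice(x, patch, (start.astype(jnp.int32),))

# Immutability: neither call modified `x`; only the returned arrays differ.
```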
Problem
Implement update_window(x, start_idx, update) that writes update into
x starting at start_idx using lax.dynamic_update_slice.
- x: 1-D jax array.
- start_idx: scalar; cast it to jnp.int32.
- update: 1-D jax array to insert.
- Returns: 1-D array with the same shape as x.
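One possible sketch of update_window, not the page's official solution. It assumes update already has x's dtype and that start_idx may be a Python number, a NumPy/JAX scalar, or a traced value. Because the body uses only traced-friendly operations, it also works under jax.jit.

```python
import jax
import jax.numpy as jnp
from jax import lax


def update_window(x, start_idx, update):
    # Normalize the start index to an int32 scalar, as the spec requires;
    # works for Python ints/floats and for traced values alike.
    start = jnp.asarray(start_idx).astype(jnp.int32)
    # A 1-D operand takes a one-element tuple of start indices.
    # The result has x's shape; x itself is not modified.
    return lax.dynamic_update_slice(x, update, (start,))


# Eager call and jit-compiled call behave the same.
eager = update_window(jnp.arange(5.0), 1, jnp.array([9.0, 9.0]))
jitted = jax.jit(update_window)(jnp.arange(5.0), jnp.int32(3), jnp.array([9.0, 9.0]))
```

With these inputs, eager is [0.0, 9.0, 9.0, 3.0, 4.0] and jitted is [0.0, 1.0, 2.0, 9.0, 9.0].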