
Dynamic Update Slice

Why this matters

lax.dynamic_update_slice(operand, update, start_indices) is the write counterpart of lax.dynamic_slice: where dynamic_slice reads a block out of an array at traced (runtime-variable) indices, dynamic_update_slice writes an update array into operand at such indices. Because JAX arrays are immutable, it returns a new array rather than mutating operand in place.
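To make the read/write pairing concrete, here is a small round-trip sketch (buffer sizes and values are arbitrary illustrations): a block written with dynamic_update_slice is read back with dynamic_slice from the same start index, using the update's shape as the static slice size.

```python
from jax import lax
import jax.numpy as jnp

buf = jnp.zeros(5, dtype=jnp.float32)
patch = jnp.array([3.0, 4.0, 5.0], dtype=jnp.float32)
start = jnp.int32(1)

# Write the 3-element patch into buf beginning at index 1.
written = lax.dynamic_update_slice(buf, patch, (start,))

# Read a block of static size (3,) back from the same start index.
read_back = lax.dynamic_slice(written, (start,), (3,))

print(written)    # [0. 3. 4. 5. 0.]
print(read_back)  # [3. 4. 5.]
```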

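lax.dynamic_update_slice is the JIT-compatible replacement for the NumPy-style assignment x[start:start + n] = update. JAX arrays do not support item assignment, and the functional form x.at[start:start + n].set(update) needs concrete slice bounds inside jit, because the size of the slice must be known at trace time. Use dynamic_update_slice whenever you need to write a block of values into a larger buffer at a traced offset inside a jit-compiled function.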
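A minimal sketch of that use case, assuming a hypothetical helper write_block: the start offset is an argument of a jit-compiled function, so it is traced, yet the call compiles because only the update's shape has to be static.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def write_block(buf, block, start):
    # `start` is a tracer here: its value is unknown at trace time.
    # dynamic_update_slice only needs the *shape* of `block` to be static.
    return lax.dynamic_update_slice(buf, block, (start,))

buf = jnp.zeros(8, dtype=jnp.float32)
block = jnp.ones(3, dtype=jnp.float32)

out = write_block(buf, block, jnp.int32(3))
print(out)  # [0. 0. 0. 1. 1. 1. 0. 0.]

# A different offset reuses the same compiled function (same shapes/dtypes).
out0 = write_block(buf, block, jnp.int32(0))
print(out0)  # [1. 1. 1. 0. 0. 0. 0. 0.]
```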

Worked mini-example

from jax import lax
import jax.numpy as jnp

x = jnp.zeros(6, dtype=jnp.float32)
patch = jnp.array([7.0, 8.0])
result = lax.dynamic_update_slice(x, patch, (jnp.int32(2),))
# → [0.0, 0.0, 7.0, 8.0, 0.0, 0.0]

The output has the same shape as operand. The update array is written starting at the position given by start_indices.
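The same call generalizes to higher-rank arrays: start_indices holds one index per operand dimension, and update must have the same rank as operand. A small 2-D sketch (the grid and tile sizes are arbitrary):

```python
from jax import lax
import jax.numpy as jnp

grid = jnp.zeros((4, 5), dtype=jnp.float32)       # operand: 4 x 5
tile = jnp.full((2, 2), 9.0, dtype=jnp.float32)   # update: 2 x 2

# One start index per dimension: row 1, column 2.
out = lax.dynamic_update_slice(grid, tile, (jnp.int32(1), jnp.int32(2)))

print(out.shape)  # (4, 5): the operand's shape, not the tile's
print(out)
# [[0. 0. 0. 0. 0.]
#  [0. 0. 9. 9. 0.]
#  [0. 0. 9. 9. 0.]
#  [0. 0. 0. 0. 0.]]
```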

Common pitfalls

  • Shape of output. The result has the shape of operand, not update.
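  • update must fit. If start_idx + len(update) exceeds len(operand), JAX does not raise an error: the start index is clamped so that the whole update fits, so the update silently lands earlier than requested. An update larger than the operand along any dimension is rejected with a shape error.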
  • Float start index. Cast to jnp.int32; passing a float tracer raises a type error.
  • Immutability. The call returns a new array โ€” the original operand is unchanged. Forgetting to capture the return value is a common mistake.
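Two of these pitfalls are easy to check directly. This sketch reuses the buffer from the mini-example: a start of 5 would need positions 5 and 6 in a length-6 array, so the start is clamped to 6 - 2 = 4; and a call whose result is not captured leaves x untouched.

```python
from jax import lax
import jax.numpy as jnp

x = jnp.zeros(6, dtype=jnp.float32)
patch = jnp.array([7.0, 8.0], dtype=jnp.float32)

# Out-of-range start: clamped to len(x) - len(patch) = 4, no error raised.
clamped = lax.dynamic_update_slice(x, patch, (jnp.int32(5),))
print(clamped)  # [0. 0. 0. 0. 7. 8.]

# Immutability: the result is discarded here, and x is unchanged.
lax.dynamic_update_slice(x, patch, (jnp.int32(2),))
print(x)  # [0. 0. 0. 0. 0. 0.]
```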

Problem

Implement update_window(x, start_idx, update) that writes update into x starting at start_idx using lax.dynamic_update_slice.

  • x: 1-D jax array.
  • start_idx: scalar; cast it to jnp.int32.
  • update: 1-D jax array to insert.
  • Returns: a 1-D array with the same shape as x.
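One possible implementation, offered as a sketch rather than the site's reference solution. It assumes x and update share a dtype, and casts start_idx with jnp.asarray(...).astype(jnp.int32) so that Python ints, floats and traced values are all handled the same way.

```python
import jax
import jax.numpy as jnp
from jax import lax

def update_window(x, start_idx, update):
    # Cast the scalar start to int32 (works for Python numbers and tracers).
    start = jnp.asarray(start_idx).astype(jnp.int32)
    # x is 1-D, so start_indices is a one-element tuple.
    return lax.dynamic_update_slice(x, update, (start,))

x = jnp.arange(6, dtype=jnp.float32)
upd = jnp.array([-1.0, -2.0], dtype=jnp.float32)

print(update_window(x, 3, upd))           # [ 0.  1.  2. -1. -2.  5.]
print(jax.jit(update_window)(x, 1, upd))  # [ 0. -1. -2.  3.  4.  5.]
```

Because the cast happens inside the function, the same code works whether start_idx arrives as a concrete Python number or as a tracer under jax.jit.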

Hints

jax dynamic-update-slice lax
