Dynamic Slice
Why this matters
lax.dynamic_slice(x, start_indices, slice_sizes) slices an array with
traced (runtime-variable) start indices. This is essential under jit:
plain Python slicing x[start:start+size] requires concrete indices and
will error when start is a JAX tracer.
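To see the failure concretely, here is a small sketch. python_slice is a hypothetical helper. The exact exception class raised for a traced slice bound has changed between JAX versions, so the code catches Exception broadly instead of naming one type.

```python
import jax
import jax.numpy as jnp


def python_slice(x, start):
    # Hypothetical helper: NumPy-style slicing whose bounds depend on `start`.
    return x[start:start + 3]


x = jnp.arange(10, dtype=jnp.float32)

# Outside jit, `start` is a concrete Python int, so the slice works.
eager = python_slice(x, 5)

# Under jit, `start` is traced, so the slice bounds are no longer static.
try:
    jax.jit(python_slice)(x, 5)
    jit_failed = False
except Exception as err:  # exact exception class differs across JAX versions
    jit_failed = True
    print(f"{type(err).__name__}: slicing with a traced start failed")
```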
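dynamic_slice is the XLA-idiomatic solution: the start indices may be integer values computed at runtime, while slice_sizes must be static Python ints known at trace time. The sizes must be static because XLA needs every array shape, including the shape of the output, at compile time.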
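As an illustration of a start index that only exists at runtime, the sketch below slices a 3-element window around the largest element. window_around_max is a hypothetical helper, and clipping the start at 0 is an assumption made here to keep the index non-negative.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def window_around_max(x):
    # Hypothetical helper: the start index is computed from the data, so it
    # is a tracer under jit. Clip at 0 so the start is never negative.
    start = jnp.maximum(jnp.argmax(x).astype(jnp.int32) - 1, 0)
    # slice_sizes is a plain Python tuple, so the output shape (3,) is known
    # at trace time even though `start` is not.
    return lax.dynamic_slice(x, (start,), (3,))


x = jnp.array([0.0, 2.0, 9.0, 4.0, 1.0], dtype=jnp.float32)
w = window_around_max(x)  # → [2.0, 9.0, 4.0]
```

The start depends on the data, but the window length does not, which is exactly the split dynamic_slice requires.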
Worked mini-example
```python
from jax import lax
import jax.numpy as jnp

x = jnp.arange(10, dtype=jnp.float32)
# Extract 3 elements starting at index 5
window = lax.dynamic_slice(x, (jnp.int32(5),), (3,))
# → [5.0, 6.0, 7.0]
```
The first argument is the operand, the second is a tuple of int32 start indices (one per axis), and the third is a tuple of static slice sizes.
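The same call pattern extends to more axes: one start index and one static size per axis. A minimal 2-D sketch (the 4×5 matrix is an arbitrary example):

```python
import jax.numpy as jnp
from jax import lax

m = jnp.arange(20, dtype=jnp.float32).reshape(4, 5)
# Start at row 1, column 2; take 2 rows and 3 columns.
block = lax.dynamic_slice(m, (jnp.int32(1), jnp.int32(2)), (2, 3))
# → [[ 7.0,  8.0,  9.0],
#    [12.0, 13.0, 14.0]]
```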
Common pitfalls
- Python slicing under jit. x[start:start+size] fails when start is a tracer, because NumPy-style slice bounds must be static (recent JAX versions raise an IndexError that points to lax.dynamic_slice). Always use lax.dynamic_slice for index-dependent slices under jit.
- Float start indices. Cast to jnp.int32 explicitly; passing a float start index raises a type error.
- Static slice_sizes. Unlike start_indices, slice_sizes must be a tuple of Python ints, not JAX arrays. Use int(window_size) if the caller passes a float.
- Out-of-bounds clamping. JAX clamps out-of-bounds start indices rather than raising: each start index is pulled back so that the whole slice fits inside the array. You won't get an error, just a silently shifted window.
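The last two pitfalls can be checked directly. In this sketch, a request that would run past the end of the array comes back shifted rather than raising, and a float start index is cast to int32 before use.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10, dtype=jnp.float32)

# Start 8 with size 3 would need indices 8, 9, 10. The start is clamped to
# len(x) - size = 7 instead of raising.
clamped = lax.dynamic_slice(x, (jnp.int32(8),), (3,))
# → [7.0, 8.0, 9.0]

# A float start index must be cast to an integer type first.
start = jnp.asarray(2.0, dtype=jnp.float32)
ok = lax.dynamic_slice(x, (start.astype(jnp.int32),), (3,))
# → [2.0, 3.0, 4.0]
```

If silent clamping is unacceptable, the start index has to be validated before the call, since dynamic_slice itself never reports it.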
Problem
Implement extract_window(x, start_idx, window_size) that extracts a
contiguous window from a 1-D array using lax.dynamic_slice.
- x: 1-D JAX array.
- start_idx: scalar; cast to jnp.int32.
- window_size: scalar; cast to int (it must be static).
- Returns: 1-D array of length window_size.
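One possible sketch of extract_window that follows the spec above (this is not the site's reference solution). Under jit, window_size has to be marked static, here via static_argnums=2, so that int() receives a Python value rather than a tracer; extract_window_jit is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax import lax


def extract_window(x, start_idx, window_size):
    # The start index may be traced; cast it to int32 as the spec requires.
    start = jnp.asarray(start_idx).astype(jnp.int32)
    # slice_sizes must be a static Python int, so window_size cannot be traced.
    size = int(window_size)
    return lax.dynamic_slice(x, (start,), (size,))


# Mark window_size (argument index 2) as static so int() sees a Python value.
extract_window_jit = jax.jit(extract_window, static_argnums=2)

x = jnp.arange(10, dtype=jnp.float32)
w = extract_window_jit(x, 4, 3)  # → [4.0, 5.0, 6.0]
```

Because window_size is static, each distinct window size triggers a separate compilation. The start index, in contrast, can change freely between calls without retracing.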