
Dynamic Slice

Why this matters

lax.dynamic_slice(x, start_indices, slice_sizes) slices an array with traced (runtime-variable) start indices. This is essential under jit. Plain Python slicing x[start:start+size] needs concrete bounds because JAX must know every array shape at trace time, so it fails when start is a JAX tracer.

dynamic_slice is the XLA-idiomatic solution. Each start index may be a scalar integer computed at runtime, while slice_sizes must be static Python ints known at trace time, so the output shape is fixed when the function is compiled.
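A minimal sketch of the jit case described above, assuming a recent JAX release; the function name take_three is hypothetical. The start index arrives inside the jitted function as a traced int32 scalar, yet tracing succeeds because the output shape depends only on the static size (3,).

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def take_three(x, start):
    # `start` is a tracer here; dynamic_slice accepts it as a start index,
    # while the slice size (3,) is a static Python tuple fixed at trace time.
    return lax.dynamic_slice(x, (start,), (3,))


x = jnp.arange(10, dtype=jnp.float32)
# The output shape (3,) does not depend on the runtime value of `start`,
# so the same compiled function serves every start position.
a = take_three(x, jnp.int32(2))  # [2., 3., 4.]
b = take_three(x, jnp.int32(6))  # [6., 7., 8.]
```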

Worked mini-example

```python
from jax import lax
import jax.numpy as jnp

x = jnp.arange(10, dtype=jnp.float32)
# Extract 3 elements starting at index 5
window = lax.dynamic_slice(x, (jnp.int32(5),), (3,))
# → [5.0, 6.0, 7.0]
```

The first argument is the operand, the second is a tuple of int32 start indices (one per axis), and the third is a tuple of static slice sizes.
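To illustrate the one-start-per-axis rule, here is a small 2-D sketch; the array m and the variable names are hypothetical. There are two start indices (row, column) and two static sizes (rows, columns).

```python
import jax.numpy as jnp
from jax import lax

m = jnp.arange(20, dtype=jnp.int32).reshape(4, 5)
# m = [[ 0  1  2  3  4]
#      [ 5  6  7  8  9]
#      [10 11 12 13 14]
#      [15 16 17 18 19]]

# Start at row 1, column 2; take 2 rows and 3 columns.
block = lax.dynamic_slice(m, (jnp.int32(1), jnp.int32(2)), (2, 3))
# block = [[ 7  8  9]
#          [12 13 14]]
```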

Common pitfalls

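  • Python slicing under jit. x[start:start+size] raises an error (an IndexError in recent JAX versions) when start is a tracer, because the slice bounds, and therefore the output length, are not known at trace time. Always use lax.dynamic_slice for index-dependent slices under jit.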
  • Float start indices. Cast to jnp.int32 explicitly; passing a float tracer raises a type error.
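  • Static slice_sizes. Unlike start_indices, slice_sizes must be a tuple of Python ints, not JAX arrays. Use int(window_size) if the caller passes a float. If window_size is an argument of a jitted function, mark it static (for example with static_argnames), because calling int() on a tracer fails.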
  • Out-of-bounds clamping. lax.dynamic_slice clamps the start index so the whole window fits inside the array (start is clamped to the range [0, len - size]) rather than raising. You won't get an error, just a silently shifted window.
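The clamping and Python-slicing pitfalls can be checked directly with a short sketch. python_slice is a hypothetical helper, and because the exact exception type for traced slice bounds may differ across JAX versions, the sketch catches both IndexError and TypeError.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10, dtype=jnp.float32)

# Out-of-bounds start: 8 + 3 > 10, so the start is clamped to 10 - 3 = 7.
# No error is raised; the window is shifted left instead.
clamped = lax.dynamic_slice(x, (jnp.int32(8),), (3,))  # [7., 8., 9.]


@jax.jit
def python_slice(x, start):
    # Slice bounds are tracers here, so the output length is unknown at trace time.
    return x[start:start + 3]


try:
    python_slice(x, jnp.int32(2))
    slicing_failed = False
except (IndexError, TypeError):
    # Recent JAX raises IndexError; catching TypeError as well is a hedge for
    # versions that report this through a TypeError subclass.
    slicing_failed = True
```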

Problem

Implement extract_window(x, start_idx, window_size) that extracts a contiguous window from a 1-D array using lax.dynamic_slice.

  • x: 1-D jax array.
  • start_idx: scalar, cast to jnp.int32.
  • window_size: scalar, cast to int (must be static).
  • Returns: 1-D array of length window_size.
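One possible sketch, not the reference solution, assuming window_size is supplied as a concrete Python number. extract_window_jit is a hypothetical wrapper that shows how to mark window_size static under jit so that int() sees a Python value rather than a tracer.

```python
import jax
import jax.numpy as jnp
from jax import lax


def extract_window(x, start_idx, window_size):
    """Return a window of length window_size from 1-D x starting at start_idx.

    The start is clamped so the window stays in bounds. start_idx may be
    traced; window_size must be concrete (static).
    """
    start = jnp.asarray(start_idx).astype(jnp.int32)  # runtime start index
    size = int(window_size)                            # static slice size
    return lax.dynamic_slice(x, (start,), (size,))


# Under jit, window_size must be static so int() receives a Python value.
extract_window_jit = jax.jit(extract_window, static_argnames="window_size")

x = jnp.arange(10, dtype=jnp.float32)
w = extract_window_jit(x, 2.0, window_size=4)  # [2., 3., 4., 5.]
```

Marking window_size static means each distinct window length triggers a separate compilation, which is the cost of keeping the output shape fixed at trace time.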

