medium primitives

take_along_axis (Gather)

Why this matters

jnp.take_along_axis is the JAX primitive for per-row (or per-slice) index-based gather. It drives beam search (selecting the top-k token logits for each beam), sorting-based operations (gathering sorted elements), and any pattern where you need row-specific column picks rather than a global flat-index gather.

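Unlike jnp.take, which indexes the flattened array by default (or applies one shared set of indices along an axis when axis is given), take_along_axis pairs each slice of indices with the matching slice of the input along the chosen axis and keeps all other axes intact. The output has the same shape as indices, once indices is broadcast against the input on the non-gathered axes.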
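To make the contrast concrete, here is a minimal sketch comparing the two functions on the same inputs. The variable names (per_row, shared, flat) are illustrative only.

```python
import jax.numpy as jnp

matrix = jnp.array([[10, 20, 30, 40],
                    [50, 60, 70, 80]])   # shape (2, 4)
indices = jnp.array([[3, 0],
                     [2, 1]])            # shape (2, 2)

# Per-row gather: row i of `indices` picks columns from row i of `matrix`.
per_row = jnp.take_along_axis(matrix, indices, axis=1)   # shape (2, 2)

# jnp.take with axis=1 applies the whole index array to every row,
# so the index dimensions replace axis 1: shape (2, 2, 2).
shared = jnp.take(matrix, indices, axis=1)

# jnp.take without an axis indexes the flattened array.
flat = jnp.take(matrix, indices)                         # shape (2, 2)

print(per_row)        # [[40 10] [70 60]]
print(shared.shape)   # (2, 2, 2)
print(flat)           # [[40 10] [30 20]]
```

Only the first call treats each row of indices as belonging to the corresponding row of matrix.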

Worked mini-example

import jax.numpy as jnp

matrix  = jnp.array([[10, 20, 30, 40],
                      [50, 60, 70, 80]])   # shape (2, 4)
indices = jnp.array([[3, 0],
                      [2, 1]], dtype=jnp.int32)  # shape (2, 2)

out = jnp.take_along_axis(matrix, indices, axis=1)
# out[0] = [matrix[0,3], matrix[0,0]] = [40, 10]
# out[1] = [matrix[1,2], matrix[1,1]] = [70, 60]
# out.shape → (2, 2)
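The sorting use case mentioned earlier follows the same pattern: jnp.argsort produces per-row indices, and take_along_axis applies them. The sketch below uses made-up scores and is one way to select a per-row top-k, not the only one.

```python
import jax.numpy as jnp

scores = jnp.array([[0.1, 0.7, 0.2],
                    [0.9, 0.3, 0.5]])   # illustrative values, shape (2, 3)

# Ascending sort order for each row.
order = jnp.argsort(scores, axis=1)

# Gathering with the sort order reproduces a per-row sort.
sorted_scores = jnp.take_along_axis(scores, order, axis=1)

# Top-2 per row: reverse the ascending order and keep the first two columns.
top2_idx = order[:, ::-1][:, :2]
top2 = jnp.take_along_axis(scores, top2_idx, axis=1)

print(top2_idx)   # [[1 2] [0 2]]
print(top2)       # [[0.7 0.2] [0.9 0.5]]
```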

Common pitfalls

  • ndim mismatch: indices must have the same number of dimensions as matrix. If matrix is 2-D, indices must also be 2-D, not 1-D.
  • Float indices: indices must be integer-typed. Cast with .astype(jnp.int32) when indices arrive as floats (e.g., from JSON test input).
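  • Confusing with jnp.take: jnp.take indexes the flattened array by default, and with axis=1 it applies the same indices to every row; take_along_axis gathers per-slice. They give different results for 2-D inputs.
  • Out-of-bounds indices: JAX does not raise an error for out-of-range indices on any backend. What comes back depends on the mode argument (clamped positions with mode="clip", fill values such as NaN for floats with mode="fill"), so a bad index can pass silently. Keep indices in range, or choose the mode explicitly.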
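The pitfalls above can be sketched in a few lines. This assumes a recent JAX release in which take_along_axis accepts a mode argument. The float indices stand in for hypothetical JSON-decoded input.

```python
import jax.numpy as jnp

matrix = jnp.array([[10, 20, 30, 40],
                    [50, 60, 70, 80]])   # shape (2, 4)

# ndim mismatch: one column per row given as a 1-D array.
# Add a trailing axis so indices is 2-D, then squeeze it back out.
cols = jnp.array([3, 1])                                       # shape (2,)
picked = jnp.take_along_axis(matrix, cols[:, None], axis=1)    # shape (2, 1)
picked_1d = picked[:, 0]                                       # shape (2,)

# Float indices: cast to an integer dtype before gathering.
raw = jnp.array([[3.0, 0.0],
                 [2.0, 1.0]])
out = jnp.take_along_axis(matrix, raw.astype(jnp.int32), axis=1)

# Out-of-bounds: with mode="clip", index 9 is clamped to the last column.
far = jnp.array([[9], [9]])
clipped = jnp.take_along_axis(matrix, far, axis=1, mode="clip")

print(picked_1d)   # [40 60]
print(out)         # [[40 10] [70 60]]
print(clipped)     # [[40] [80]]
```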

Problem

Implement gather_top_k_per_row(matrix, indices) that gathers elements from matrix along axis 1 using indices.

  • matrix: 2-D array of shape (N, K).
  • indices: 2-D integer array of shape (N, M): for each row i, the column positions to pick.
  • Returns: 2-D array of shape (N, M) where out[i, j] = matrix[i, indices[i, j]].

Two illustrative examples (not from the test set):

  • matrix = [[1, 2, 3]], indices = [[2, 0]]: out = [[3, 1]]

  • matrix = [[10, 20], [30, 40]], indices = [[1, 0], [0, 0]]: out = [[20, 10], [30, 30]]
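One possible approach, offered as a sketch rather than the reference solution, is to wrap jnp.take_along_axis directly. The integer cast is an added assumption to cover indices that arrive as floats. It is not required by the problem statement.

```python
import jax.numpy as jnp

def gather_top_k_per_row(matrix, indices):
    """Return out with out[i, j] = matrix[i, indices[i, j]]."""
    matrix = jnp.asarray(matrix)
    # Assumption: accept list or float input and coerce to int32 indices.
    indices = jnp.asarray(indices).astype(jnp.int32)
    return jnp.take_along_axis(matrix, indices, axis=1)

# The two illustrative examples from the problem statement.
print(gather_top_k_per_row([[1, 2, 3]], [[2, 0]]))
# [[3 1]]
print(gather_top_k_per_row([[10, 20], [30, 40]], [[1, 0], [0, 0]]))
# [[20 10] [30 30]]
```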

Hints

jax indexing take gather
