take_along_axis (Gather)
Why this matters
jnp.take_along_axis is the JAX function for per-row (or per-slice)
index-based gather. It drives beam search (selecting the top-k token logits
for each beam), sorting-based operations (gathering sorted elements), and any
pattern where you need row-specific column picks rather than a global
flat-index gather.
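As a small illustration of the sorting use case, the sketch below sorts each row in descending order and keeps the two largest values. The array `scores` and the cutoff of 2 are made up for this example; they stand in for per-beam logits and a top-k size.

```python
import jax.numpy as jnp

# Hypothetical per-row scores (think: one row per beam).
scores = jnp.array([[0.1, 0.7, 0.2],
                    [0.9, 0.3, 0.5]])  # shape (2, 3)

# For each row, the column order that sorts it in descending order.
order = jnp.argsort(-scores, axis=1)   # shape (2, 3)

# Gather each row's values in its own order, then keep the first 2 columns.
sorted_scores = jnp.take_along_axis(scores, order, axis=1)
top2 = sorted_scores[:, :2]            # shape (2, 2)
```

Each row of `order` differs, which is exactly the case where a single shared index list would not work.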
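Unlike jnp.take, which applies one shared set of indices along an axis (or to
the flattened array when no axis is given), take_along_axis gives each slice
its own indices along the chosen axis while keeping all other axes intact. When
indices and the input agree on every other axis, the output has exactly the
shape of indices.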
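To make the difference concrete, here is a minimal comparison on one small input. The arrays `matrix` and `idx` are illustrative values chosen for this sketch.

```python
import jax.numpy as jnp

matrix = jnp.array([[10, 20, 30],
                    [40, 50, 60]])   # shape (2, 3)
idx = jnp.array([[2, 0],
                 [1, 1]])            # shape (2, 2)

# take_along_axis: row i of matrix is indexed by row i of idx.
per_row = jnp.take_along_axis(matrix, idx, axis=1)   # shape (2, 2)

# take with axis=1: the whole idx array is applied to every row,
# so the result gains idx's dimensions in place of axis 1.
shared = jnp.take(matrix, idx, axis=1)               # shape (2, 2, 2)

# take without an axis: idx indexes the flattened matrix.
flat = jnp.take(matrix, idx)                         # shape (2, 2)
```

Here `per_row` is `[[30, 10], [50, 50]]`, while `flat` is `[[30, 10], [20, 20]]`: the second row differs because the flat gather never looks at the second row of `matrix`.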
Worked mini-example
import jax.numpy as jnp
matrix = jnp.array([[10, 20, 30, 40],
                    [50, 60, 70, 80]])          # shape (2, 4)
indices = jnp.array([[3, 0],
                     [2, 1]], dtype=jnp.int32)  # shape (2, 2)
out = jnp.take_along_axis(matrix, indices, axis=1)
# out[0] = [matrix[0,3], matrix[0,0]] = [40, 10]
# out[1] = [matrix[1,2], matrix[1,1]] = [70, 60]
# out.shape == (2, 2)
Common pitfalls
- ndim mismatch: indices must have the same number of dimensions as matrix. If matrix is 2-D, indices must also be 2-D, not 1-D.
- Float indices: indices must be integer-typed. Cast with .astype(jnp.int32) when indices arrive as floats (e.g., from JSON test input).
- Confusing with jnp.take: jnp.take applies the same indices to every slice (or to the flattened array when no axis is given); take_along_axis gathers per-slice. They give different results for 2-D inputs.
- Out-of-bounds indices: JAX does not raise an error for out-of-range indices. With the default mode="fill", take_along_axis returns a fill value (NaN for floating-point arrays) at those positions; mode="clip" clamps the index instead. Keep indices in range.
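The pitfalls above can be sidestepped with a few one-line fixes. This is a sketch under assumed inputs: `raw`, `cols`, and `too_big` are hypothetical arrays invented for illustration, and clamping with `jnp.clip` is one possible policy for out-of-range indices, not the only one.

```python
import jax.numpy as jnp

matrix = jnp.array([[10, 20, 30, 40],
                    [50, 60, 70, 80]])        # shape (2, 4)

# Float indices (e.g. decoded from JSON): cast to an integer dtype first.
raw = jnp.array([[3.0, 0.0], [2.0, 1.0]])
idx = raw.astype(jnp.int32)
out = jnp.take_along_axis(matrix, idx, axis=1)               # shape (2, 2)

# One column per row given as a 1-D array: add a trailing axis so that
# indices and matrix have the same number of dimensions.
cols = jnp.array([3, 1])                                     # shape (2,)
picked = jnp.take_along_axis(matrix, cols[:, None], axis=1)  # shape (2, 1)
picked_1d = picked[:, 0]                                     # shape (2,)

# Possibly out-of-range indices: clamp them explicitly before gathering.
too_big = jnp.array([[9], [1]])
safe = jnp.clip(too_big, 0, matrix.shape[1] - 1)
clamped = jnp.take_along_axis(matrix, safe, axis=1)          # shape (2, 1)
```

Clamping hides the bad index rather than reporting it, so validating inputs before the gather may be preferable when out-of-range values indicate a bug.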
Problem
Implement gather_top_k_per_row(matrix, indices) that gathers elements from
matrix along axis 1 using indices.
- matrix: 2-D array of shape (N, K).
- indices: 2-D integer array of shape (N, M); for each row i, the column positions to pick.
- Returns: 2-D array of shape (N, M) where out[i, j] = matrix[i, indices[i, j]].
Two illustrative examples (not from the test set):
- matrix = [[1, 2, 3]], indices = [[2, 0]]: out = [[3, 1]]
- matrix = [[10, 20], [30, 40]], indices = [[1, 0], [0, 0]]: out = [[20, 10], [30, 30]]
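One possible sketch of the function, together with an explicit-indexing version that spells out out[i, j] = matrix[i, indices[i, j]]. This is not the site's reference solution; the helper name `gather_reference` and the integer cast are assumptions made for this example.

```python
import jax.numpy as jnp

def gather_top_k_per_row(matrix, indices):
    """Gather matrix[i, indices[i, j]] for every (i, j) via take_along_axis."""
    matrix = jnp.asarray(matrix)
    # Assumption: indices may arrive as floats, so cast defensively.
    indices = jnp.asarray(indices).astype(jnp.int32)
    return jnp.take_along_axis(matrix, indices, axis=1)

def gather_reference(matrix, indices):
    """Hypothetical cross-check using advanced indexing instead."""
    matrix = jnp.asarray(matrix)
    indices = jnp.asarray(indices).astype(jnp.int32)
    # rows has shape (N, 1) and broadcasts against indices of shape (N, M).
    rows = jnp.arange(matrix.shape[0])[:, None]
    return matrix[rows, indices]

# The two illustrative examples from the problem statement.
ex1 = gather_top_k_per_row([[1, 2, 3]], [[2, 0]])
ex2 = gather_top_k_per_row([[10, 20], [30, 40]], [[1, 0], [0, 0]])
```

Both functions should agree on in-range inputs; the advanced-indexing form makes the row pairing explicit, while take_along_axis states the intent more directly.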