Gather elements from a 2D tensor using an index tensor along axis 1.
For each row i, select the elements of x[i] at the column positions given by indices[i], so that out[i][j] = x[i][indices[i][j]].
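A minimal plain-Python sketch of the row-wise rule, using nested lists as a stand-in for tensors. The function name `gather_rows` is hypothetical and used only for illustration. Python's own negative-index wrapping is not part of this spec, which requires every index to lie in [0, n).

```python
def gather_rows(x, indices):
    """Reference gather along axis 1: out[i][j] = x[i][indices[i][j]]."""
    # One output row per input row; output row i has len(indices[i]) entries.
    return [[row[c] for c in idx_row] for row, idx_row in zip(x, indices)]


x = [[10, 11, 12],
     [20, 21, 22]]
indices = [[2, 0],
           [1, 1]]
print(gather_rows(x, indices))  # [[12, 10], [21, 21]]
```

Indices may repeat, and k may be smaller or larger than n.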
Inputs:
x: A 2D tensor of shape (m, n)
indices: A 2D integer tensor of shape (m, k); each value is a column index in [0, n)
Output: A 2D tensor of shape (m, k) containing the gathered elements, with the same dtype as x.
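The shape and range contract above can be checked before gathering. This is a sketch over nested lists; `check_gather_args` is a hypothetical helper, not part of either library.

```python
def check_gather_args(x, indices):
    """Validate the gather contract and return the output shape (m, k).

    x is an m-by-n nested list; indices is an m-by-k nested list of ints,
    each in [0, n). Raises ValueError on the first violation found.
    """
    m = len(x)
    if len(indices) != m:
        raise ValueError(f"x has {m} rows but indices has {len(indices)}")
    n = len(x[0]) if m else 0
    k = len(indices[0]) if m else 0
    for i in range(m):
        if len(x[i]) != n or len(indices[i]) != k:
            raise ValueError(f"row {i} is ragged")
        for j, c in enumerate(indices[i]):
            # Reject non-integers and anything outside [0, n), including negatives.
            if not isinstance(c, int) or not 0 <= c < n:
                raise ValueError(f"indices[{i}][{j}] = {c!r} is outside [0, {n})")
    return (m, k)


print(check_gather_args([[10, 11, 12], [20, 21, 22]], [[2, 0], [1, 1]]))  # (2, 2)
```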
API Reference:
PyTorch: torch.gather(x, dim=1, index=indices)
JAX: jnp.take_along_axis(x, indices, axis=1)
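A short JAX sketch of the reference call, checked against an equivalent per-row formulation that uses `jax.vmap` over plain fancy indexing. The sample values are illustrative only. In PyTorch, `torch.gather(x, dim=1, index=indices)` yields the same result, with `indices` given as an int64 (`torch.long`) tensor.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[10, 11, 12],
               [20, 21, 22]])
indices = jnp.array([[2, 0],
                     [1, 1]])

# Gather along axis 1: out[i, j] = x[i, indices[i, j]].
out = jnp.take_along_axis(x, indices, axis=1)

# Equivalent per-row view: for each (row, idx) pair, take row[idx].
out_vmap = jax.vmap(lambda row, idx: row[idx])(x, indices)

print(out.tolist())  # [[12, 10], [21, 21]]
```

The `vmap` form makes the "one independent lookup per row" reading explicit, while `take_along_axis` states it as a single array operation.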