easy framework

Tensor Indexing and Slicing

Extract a sub-tensor from a 2D tensor by selecting specific rows.

Given a 2D tensor x and a list of row indices rows, return the tensor formed by selecting only those rows, in the order the indices appear in rows.

Input:

  • x: A 2D tensor of shape (m, n)
  • rows: A list of integer row indices

Output: A 2D tensor of shape (len(rows), n) with the selected rows.
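
To make the expected behavior concrete, here is a minimal framework-free sketch that uses nested Python lists in place of tensors. The name select_rows_reference is hypothetical and not part of the problem; it only illustrates the intended semantics, under the assumption that output order follows rows and repeated indices produce repeated rows.

```python
def select_rows_reference(x, rows):
    """Return the rows of x named by `rows`, in the order given.

    x is a list of m rows, each a list of n values (a stand-in for a 2D tensor).
    rows is a list of integer row indices.
    """
    # Copy each selected row so the result does not alias the input.
    return [list(x[i]) for i in rows]


x = [
    [0, 1, 2],
    [3, 4, 5],
    [6, 7, 8],
    [9, 10, 11],
]
selected = select_rows_reference(x, [2, 0])
print(selected)  # [[6, 7, 8], [0, 1, 2]]
```

The result has len(rows) rows and n columns, matching the output shape stated above.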

API Reference:

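  • PyTorch: x[indices] or torch.index_select(x, 0, indices), where indices is an integer tensor such as torch.tensor(rows)
  • JAX: x[indices] or jnp.take(x, indices, axis=0), where indices is an integer array such as jnp.array(rows)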
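
As an illustration of the JAX calls listed above, the following sketch converts rows to an integer index array and then selects along axis 0. The function name select_rows and the sample data are assumptions made for this example, not a prescribed solution.

```python
import jax.numpy as jnp


def select_rows(x, rows):
    """Select the given rows of a 2D array x, in the order listed in rows."""
    # Build an integer index array from the Python list before indexing.
    indices = jnp.asarray(rows, dtype=jnp.int32)
    # Gather along axis 0, i.e. pick whole rows.
    return jnp.take(x, indices, axis=0)


x = jnp.arange(12).reshape(4, 3)  # rows: [0,1,2], [3,4,5], [6,7,8], [9,10,11]
out = select_rows(x, [2, 0])
print(out.shape)  # (2, 3)
```

Indexing with x[indices] on the same index array should give the same result; jnp.take is used here only because it makes the axis explicit.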

Hints

indexing, slicing, torch.index_select, jnp.take