Perform batched matrix multiplication on two 3D tensors.
Given A of shape (batch, m, k) and B of shape (batch, k, n),
compute C of shape (batch, m, n) where C[i] = A[i] @ B[i] for each batch element.
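To make the definition C[i] = A[i] @ B[i] concrete, here is a minimal pure-Python reference sketch. It assumes tensors are modelled as nested lists with non-empty dimensions. The name `bmm_reference` is hypothetical and is not part of any library. Each batch element is multiplied independently, and entry (r, c) of the output is the dot product of row r of A[i] with column c of B[i].

```python
from typing import List

# Hypothetical nested-list stand-ins for 2D and 3D tensors (illustration only).
Matrix = List[List[float]]
Batch = List[Matrix]


def bmm_reference(A: Batch, B: Batch) -> Batch:
    """Batched matmul: A is (batch, m, k), B is (batch, k, n) -> (batch, m, n).

    Assumes every dimension is non-empty.
    """
    if len(A) != len(B):
        raise ValueError("batch sizes differ")
    C: Batch = []
    for a, b in zip(A, B):  # one independent matrix product per batch element
        m, k = len(a), len(a[0])
        if len(b) != k:
            raise ValueError("inner dimensions differ")
        n = len(b[0])
        # C[i][r][c] = sum over j of A[i][r][j] * B[i][j][c]
        c_mat = [[sum(a[r][j] * b[j][c] for j in range(k)) for c in range(n)]
                 for r in range(m)]
        C.append(c_mat)
    return C


A = [[[1, 2], [3, 4]], [[1, 0], [0, 1]]]  # batch=2, m=2, k=2
B = [[[5, 6], [7, 8]], [[2, 3], [4, 5]]]  # batch=2, k=2, n=2
print(bmm_reference(A, B))  # [[[19, 22], [43, 50]], [[2, 3], [4, 5]]]
```

The second batch element uses the identity matrix as A[1], so C[1] simply equals B[1]. This shows that each slice is computed on its own.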
Input:
A: A 3D tensor of shape (batch, m, k)
B: A 3D tensor of shape (batch, k, n)
Output: A 3D tensor of shape (batch, m, n).
API Reference:
PyTorch: torch.bmm(A, B) or A @ B
JAX: jnp.matmul(A, B) or A @ B
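The JAX forms above can be checked against two equivalent formulations: an explicit per-batch product via `jax.vmap`, and an index-notation contraction via `jnp.einsum`. This is a sketch, not a reference solution. The wrapper name `bmm`, the random shapes, and the tolerance are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def bmm(A: jax.Array, B: jax.Array) -> jax.Array:
    # Hypothetical wrapper enforcing the stated contract:
    # (batch, m, k) x (batch, k, n) -> (batch, m, n).
    assert A.ndim == 3 and B.ndim == 3, "both inputs must be 3D"
    assert A.shape[0] == B.shape[0], "batch sizes must match"
    assert A.shape[2] == B.shape[1], "inner dimensions must match"
    return jnp.matmul(A, B)  # equivalently A @ B


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (4, 2, 3))  # (batch, m, k)
B = jax.random.normal(key_b, (4, 3, 5))  # (batch, k, n)

C = bmm(A, B)
C_vmap = jax.vmap(jnp.matmul)(A, B)          # explicit C[i] = A[i] @ B[i]
C_einsum = jnp.einsum("bmk,bkn->bmn", A, B)  # same contraction, index notation

print(C.shape)  # (4, 2, 5)
```

The explicit shape asserts are a design choice. `torch.bmm` accepts only 3D inputs with matching batch sizes and does not broadcast. In contrast, `A @ B` and `jnp.matmul` broadcast leading batch dimensions. The asserts keep the sketch to the strict 3D contract described in this problem.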