Use Einstein summation notation to compute the batch-wise trace of a stack of square matrices.
The trace of a square matrix is the sum of its diagonal elements. Given a batch of matrices of shape (batch, n, n),
return a 1D tensor of shape (batch,) whose b-th element is the trace of the b-th matrix, x[b].
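In index notation, the quantity computed for each batch entry is shown below. The index i is repeated within a single operand, so only diagonal entries are selected, and because i does not appear in the output it is summed over. This is what the subscript string 'bii->b' expresses.

```latex
y_b = \operatorname{tr}(X_b) = \sum_{i=1}^{n} x_{b,i,i}, \qquad b = 1, \dots, \text{batch}
```

For example, if X_b = [[1, 2], [3, 4]], then y_b = 1 + 4 = 5.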
Input: A 3D tensor x of shape (batch, n, n).
Output: A 1D tensor of shape (batch,) containing the trace of each matrix.
API Reference:
PyTorch: torch.einsum('bii->b', x)
JAX: jnp.einsum('bii->b', x)
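A minimal sketch of a solution follows. It uses NumPy as a stand-in because np.einsum accepts the same 'bii->b' subscript string. With PyTorch or JAX, replace np.einsum with torch.einsum or jnp.einsum. The function name batch_trace and the input shape check are illustrative assumptions and are not part of the task statement.

```python
import numpy as np


def batch_trace(x: np.ndarray) -> np.ndarray:
    """Return the trace of each matrix in a (batch, n, n) array as a (batch,) array.

    Hypothetical helper; with PyTorch/JAX swap np.einsum for torch.einsum/jnp.einsum.
    """
    # Assumed validation: the task only guarantees a (batch, n, n) input.
    if x.ndim != 3 or x.shape[1] != x.shape[2]:
        raise ValueError(f"expected shape (batch, n, n), got {x.shape}")
    # 'bii->b': repeating i inside one operand selects the diagonal x[b, i, i];
    # i is absent from the output, so it is summed, leaving one value per b.
    return np.einsum('bii->b', x)


# Usage: a batch of two 2x2 matrices.
x = np.array([[[1, 2], [3, 4]],
              [[5, 6], [7, 8]]])
print(batch_trace(x))  # [ 5 13]
```

To check the result independently, np.trace(x, axis1=1, axis2=2) computes the same batched trace without einsum.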