Difficulty: medium | Category: framework

Einstein Summation

Use Einstein summation notation to compute the batch-wise trace of a stack of square matrices.

The trace of a square matrix is the sum of its diagonal elements. Given a batch of matrices of shape (batch, n, n), return a 1D tensor of shape (batch,) whose b-th element is the trace of the b-th matrix.
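Written out, the required output for batch index b is the sum of the diagonal entries of the b-th matrix. The subscript string 'bii->b' encodes this. Repeating the index i on the input selects the diagonal entries x[b, i, i]. Because i does not appear on the output side, einsum sums over it, and b is kept as the only output axis.

$$
\mathrm{out}_b = \operatorname{tr}(x_b) = \sum_{i=1}^{n} x_{b,i,i}, \qquad b = 1, \dots, \mathrm{batch}
$$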

Input: A 3D tensor x of shape (batch, n, n).

Output: A 1D tensor of shape (batch,) containing the trace of each matrix.

API Reference:

  • PyTorch: torch.einsum('bii->b', x)
  • JAX: jnp.einsum('bii->b', x)
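A minimal PyTorch sketch of a solution follows. The function name `batch_trace` and the input shape check are illustrative assumptions and are not part of the problem statement. The result is cross-checked against an explicit "take the diagonal, then sum" formulation built from `torch.diagonal`. A JAX version would differ only in calling `jnp.einsum` with the same subscripts.

```python
import torch


def batch_trace(x: torch.Tensor) -> torch.Tensor:
    """Return the trace of each matrix in a (batch, n, n) tensor.

    'bii->b': the repeated index i selects the diagonal x[b, i, i],
    and since i is absent from the output, einsum sums over it.
    """
    # Hypothetical guard: reject inputs that are not a batch of square matrices.
    if x.dim() != 3 or x.shape[-1] != x.shape[-2]:
        raise ValueError(f"expected shape (batch, n, n), got {tuple(x.shape)}")
    return torch.einsum('bii->b', x)


# Usage: a batch of two 3x3 matrices holding the values 0..17.
x = torch.arange(18, dtype=torch.float32).reshape(2, 3, 3)
out = batch_trace(x)  # traces 0+4+8 = 12 and 9+13+17 = 39

# Cross-check: extract the diagonal of the last two dims, then sum it.
expected = torch.diagonal(x, dim1=-2, dim2=-1).sum(dim=-1)
```

Einsum expresses the diagonal selection and the reduction in a single call. The `torch.diagonal` formulation spells out the same two steps separately, which makes it a convenient reference when testing.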

Hints

einsum, trace, torch.einsum, jnp.einsum