Implement causal (autoregressive) scaled dot-product attention using framework APIs.
Apply a causal mask so that position i can only attend to positions <= i.
Masked positions should be filled with -inf before the softmax.
Input:
Q: Query tensor of shape (seq_len, d_k)
K: Key tensor of shape (seq_len, d_k)
V: Value tensor of shape (seq_len, d_v)
Output: A tensor of shape (seq_len, d_v) — the attention output.
Steps:
1. Compute the raw attention scores: Q @ K^T / sqrt(d_k).
2. Build a causal mask and fill positions where j > i with -inf.
3. Apply softmax over the last dimension to get attention weights.
4. Multiply the weights by V.
API Reference:
PyTorch: torch.triu, Tensor.masked_fill, torch.softmax
JAX: jnp.triu, jnp.where, jax.nn.softmax
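
A minimal PyTorch sketch of the steps above (function name and unbatched 2-D shapes are assumptions of this exercise, not a fixed API):

```python
import math
import torch

def causal_attention(Q, K, V):
    # Q: (seq_len, d_k), K: (seq_len, d_k), V: (seq_len, d_v)
    d_k = Q.shape[-1]
    # 1. Scaled dot-product scores, shape (seq_len, seq_len)
    scores = Q @ K.T / math.sqrt(d_k)
    # 2. Strictly upper-triangular mask marks future positions (j > i);
    #    fill them with -inf so softmax assigns them zero weight
    seq_len = scores.shape[0]
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    # 3. Softmax over the key dimension
    weights = torch.softmax(scores, dim=-1)
    # 4. Weighted sum of values, shape (seq_len, d_v)
    return weights @ V
```

A quick sanity check: position 0 can only attend to itself, so the first output row must equal V[0] exactly.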