Difficulty: hard · Category: framework

Efficient Attention with Masking

Implement causal (autoregressive) scaled dot-product attention using framework APIs.

Apply a causal mask so that position i attends only to positions j <= i. Masked (future) positions should be filled with -inf before the softmax, so they receive zero attention weight.

Input:

  • Q: Query tensor of shape (seq_len, d_k)
  • K: Key tensor of shape (seq_len, d_k)
  • V: Value tensor of shape (seq_len, d_v)

Output: A tensor of shape (seq_len, d_v) — the attention output.

Steps:

  1. Compute scores = Q @ K^T / sqrt(d_k)
  2. Create a causal mask (strict upper triangle = True, i.e., diagonal offset 1, so each position can still attend to itself)
  3. Fill masked positions with -inf
  4. Apply softmax along the last dimension
  5. Multiply by V
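
The steps above can be sketched in PyTorch as follows; the function name `causal_attention` is illustrative, not part of any API:

```python
import math
import torch

def causal_attention(Q, K, V):
    # 1. Scaled dot-product scores: (seq_len, seq_len)
    d_k = Q.shape[-1]
    scores = Q @ K.T / math.sqrt(d_k)
    # 2. Strict upper triangle (diagonal=1) marks future positions j > i
    seq_len = Q.shape[0]
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    # 3. Fill masked positions with -inf so softmax assigns them zero weight
    scores = scores.masked_fill(mask, float("-inf"))
    # 4. Normalize each row over the keys it may attend to
    weights = torch.softmax(scores, dim=-1)
    # 5. Weighted sum of values: (seq_len, d_v)
    return weights @ V
```

Because position 0 may attend only to itself, the first output row equals `V[0]` exactly, which makes a convenient sanity check.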

API Reference:

  • PyTorch: torch.triu, masked_fill, torch.softmax
  • JAX: jnp.triu, jnp.where, jax.nn.softmax
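
The same computation, sketched with the JAX APIs listed above (again, `causal_attention` is an illustrative name):

```python
import jax
import jax.numpy as jnp

def causal_attention(Q, K, V):
    # Scaled dot-product scores: (seq_len, seq_len)
    d_k = Q.shape[-1]
    scores = Q @ K.T / jnp.sqrt(d_k)
    # Strict upper triangle (k=1) marks future positions; jnp.where
    # replaces them with -inf instead of in-place masking
    seq_len = Q.shape[0]
    mask = jnp.triu(jnp.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = jnp.where(mask, -jnp.inf, scores)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ V
```

Note the idiomatic difference: JAX arrays are immutable, so `jnp.where` is used in place of PyTorch's `masked_fill`.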

Hints

  attention · causal-mask · torch.triu · jnp.triu · softmax