Implement prefix tuning from “Prefix-Tuning: Optimizing Continuous Prompts for Generation” (Li & Liang, 2021).
In prefix tuning, learnable prefix vectors are prepended to the key and value sequences in attention. The original model weights are frozen; only prefix parameters are trained.
Given:
Q: shape (seq_len, d_k), queries (from frozen model)
K: shape (seq_len, d_k), keys (from frozen model)
V: shape (seq_len, d_k), values (from frozen model)
prefix_K: shape (prefix_len, d_k), learnable prefix keys
prefix_V: shape (prefix_len, d_k), learnable prefix values
Steps:
1. Prepend prefix_K to K, giving keys of shape (prefix_len + seq_len, d_k).
2. Prepend prefix_V to V, giving values of shape (prefix_len + seq_len, d_k).
3. Compute attention of Q over the extended keys and values.
Output: Tensor of shape (seq_len, d_k).
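The computation described above can be sketched in NumPy as follows. This is a minimal illustration, not the paper's reference implementation. It assumes standard scaled dot-product attention (scores divided by sqrt(d_k)) and no causal mask, neither of which is stated in the problem. The function name prefix_attention and the example sizes are hypothetical.

```python
import numpy as np


def prefix_attention(Q, K, V, prefix_K, prefix_V):
    """Attention of Q over keys/values extended with learnable prefixes."""
    d_k = Q.shape[-1]
    # Prepend the prefixes along the sequence axis:
    # both results have shape (prefix_len + seq_len, d_k).
    K_full = np.concatenate([prefix_K, K], axis=0)
    V_full = np.concatenate([prefix_V, V], axis=0)
    # Scores have shape (seq_len, prefix_len + seq_len).
    # Assumption: scaled dot-product attention, no masking.
    scores = Q @ K_full.T / np.sqrt(d_k)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Weighted sum of the extended values: (seq_len, d_k).
    return weights @ V_full


# Hypothetical sizes, chosen only for illustration.
rng = np.random.default_rng(0)
seq_len, prefix_len, d_k = 4, 2, 8
Q = rng.normal(size=(seq_len, d_k))
K = rng.normal(size=(seq_len, d_k))
V = rng.normal(size=(seq_len, d_k))
prefix_K = rng.normal(size=(prefix_len, d_k))
prefix_V = rng.normal(size=(prefix_len, d_k))

out = prefix_attention(Q, K, V, prefix_K, prefix_V)
print(out.shape)  # (4, 8)
```

Only the key and value sequences grow, so the output keeps one row per query. In a training setup, prefix_K and prefix_V would be the only arrays receiving gradient updates, with Q, K, and V produced by the frozen model.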