medium end_to_end

Train LoRA Adapter End-to-End

Fine-tune a LoRA adapter on a linear regression task, implemented from scratch. No peft, no transformers; you own the forward pass, the gradient, and the SGD update.

What is LoRA?

LoRA (Hu et al. 2021, "LoRA: Low-Rank Adaptation of Large Language Models") expresses the weight update as the product of two small matrices:

$$W = W_{\text{base}} + A B$$

where A is (d_in, r) and B is (r, d_out) with rank r ≪ min(d_in, d_out). This reduces trainable parameters from d_in × d_out to r(d_in + d_out).
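
To make the saving concrete, the short sketch below evaluates both counts for one weight matrix. The helper name `lora_param_counts` and the sizes (1024 by 1024, rank 8) are illustrative assumptions, not part of the problem.

```python
def lora_param_counts(d_in: int, d_out: int, r: int) -> tuple[int, int]:
    """Return (full, lora) trainable-parameter counts for one weight matrix."""
    full = d_in * d_out          # training W directly
    lora = r * (d_in + d_out)    # A is (d_in, r), B is (r, d_out)
    return full, lora

full, lora = lora_param_counts(d_in=1024, d_out=1024, r=8)
print(full, lora, full // lora)  # 1048576 16384 64
```

The saving only appears when r is small relative to the layer: with d_in=4, d_out=3, r=2 the adapter has 14 parameters against 12 for the full matrix.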

Why freeze the base weights?

w_base is kept frozen: no gradient is computed or applied to it. Only A and B are updated. This is the whole point of LoRA: you adapt a pre-trained model cheaply by training only the low-rank adapter.

Forward pass

effective_w = w_base + a @ b             # (d_in, d_out)
y_hat       = x @ effective_w            # (N, d_out)
loss        = sum((y_hat - y)**2) / N    # squared error averaged over the N samples
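
As a runnable sketch of the forward pass in NumPy, the block below uses made-up sizes and random data. Two assumptions are worth flagging: the loss is normalized by N (not by N·d_out) so that it agrees with the gradient formula G = (2/N) Xᵀ(ŷ − y), and b starts at zero, which is one common initialization and makes the adapter a no-op at first.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d_in, d_out, r = 5, 4, 3, 2          # illustrative sizes

x = rng.normal(size=(N, d_in))
y = rng.normal(size=(N, d_out))
w_base = rng.normal(size=(d_in, d_out))
a = rng.normal(size=(d_in, r))
b = np.zeros((r, d_out))                 # assumption: zero B, so a @ b == 0

effective_w = w_base + a @ b             # (d_in, d_out)
y_hat = x @ effective_w                  # (N, d_out)
loss = np.sum((y_hat - y) ** 2) / N      # scalar

print(effective_w.shape, y_hat.shape)    # (4, 3) (5, 3)
```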

Gradient derivation

Gradient of the loss w.r.t. effective_w, with the squared error summed over the d_out outputs and averaged over the N samples:

$$G = \frac{2}{N} X^\top (\hat{y} - y) \quad \text{shape } (d_{\text{in}}, d_{\text{out}})$$

Because effective_w = w_base + A B and w_base is constant, the chain rule gives:

$$\nabla_A = G B^\top, \quad \nabla_B = A^\top G$$
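
One way to convince yourself of these formulas is a finite-difference check. The sketch below compares the analytic gradients with central differences on random data; the helper names (`loss_fn`, `analytic_grads`, `numeric_grad`) and the sizes are hypothetical, and the loss is assumed to be the squared error divided by N.

```python
import numpy as np

def loss_fn(x, y, w_base, a, b):
    """Squared error summed over outputs, averaged over the N samples."""
    y_hat = x @ (w_base + a @ b)
    return np.sum((y_hat - y) ** 2) / x.shape[0]

def analytic_grads(x, y, w_base, a, b):
    """Gradients from the chain rule: grad_a = G B^T, grad_b = A^T G."""
    G = (2 / x.shape[0]) * x.T @ (x @ (w_base + a @ b) - y)
    return G @ b.T, a.T @ G

def numeric_grad(f, p, eps=1e-6):
    """Central finite differences of the scalar f() with respect to array p."""
    g = np.zeros_like(p)
    for idx in np.ndindex(p.shape):
        old = p[idx]
        p[idx] = old + eps
        f_plus = f()
        p[idx] = old - eps
        f_minus = f()
        p[idx] = old                      # restore the entry
        g[idx] = (f_plus - f_minus) / (2 * eps)
    return g

rng = np.random.default_rng(0)
x, y = rng.normal(size=(6, 4)), rng.normal(size=(6, 3))
w_base = rng.normal(size=(4, 3))
a, b = rng.normal(size=(4, 2)), rng.normal(size=(2, 3))

grad_a, grad_b = analytic_grads(x, y, w_base, a, b)
f = lambda: loss_fn(x, y, w_base, a, b)   # reads a and b as perturbed in place
err_a = np.max(np.abs(grad_a - numeric_grad(f, a)))
err_b = np.max(np.abs(grad_b - numeric_grad(f, b)))
print(err_a < 1e-5, err_b < 1e-5)
```

The shapes also serve as a quick sanity check: grad_a is (d_in, r) like A, and grad_b is (r, d_out) like B.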

SGD update

a -= lr * grad_a
b -= lr * grad_b
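
A subtle point is that both gradients should be computed from the current a and b before either one is updated. The sketch below shows a single simultaneous step on tiny hand-picked matrices; the function name `sgd_step` and the numbers are illustrative only.

```python
import numpy as np

def sgd_step(a, b, G, lr):
    """One simultaneous SGD step on both factors, given G = dL/d(effective_w)."""
    grad_a = G @ b.T      # uses the current b
    grad_b = a.T @ G      # uses the current a, not the updated one
    return a - lr * grad_a, b - lr * grad_b

a = np.array([[1.0], [2.0]])         # (d_in=2, r=1)
b = np.array([[3.0, 4.0]])           # (r=1, d_out=2)
G = np.ones((2, 2))
new_a, new_b = sgd_step(a, b, G, lr=0.1)
print(new_a.ravel(), new_b.ravel())  # [0.3 1.3] [2.7 3.7]
```

Returning new arrays, as here, leaves the caller's a and b untouched; the in-place `-=` form gives the same values but mutates its inputs.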

Algorithm

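import numpy as np

def train_lora(x, y, w_base, a0, b0, lr, n_steps):  # function name is illustrative
    a, b = a0.copy(), b0.copy()            # do not mutate the caller's arrays
    N = x.shape[0]
    for _ in range(n_steps):
        effective_w = w_base + a @ b       # w_base is never updated
        y_hat = x @ effective_w
        G = (2 / N) * x.T @ (y_hat - y)    # dL/d(effective_w)
        grad_a = G @ b.T                   # both gradients use the pre-step a, b
        grad_b = a.T @ G
        a = a - lr * grad_a
        b = b - lr * grad_b
    return np.concatenate([a.flatten(), b.flatten()])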
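
To see the loop in action, the sketch below runs a variant of the Algorithm section that also records the loss before each step. Everything about the setup is an assumption for illustration: the synthetic data, the sizes, the learning rate, the zero initialization of b0, and the name `train_lora_with_history`.

```python
import numpy as np

def train_lora_with_history(x, y, w_base, a0, b0, lr, n_steps):
    """Same loop as the Algorithm section, plus a per-step loss record."""
    a, b = a0.copy(), b0.copy()
    N = x.shape[0]
    losses = []
    for _ in range(n_steps):
        residual = x @ (w_base + a @ b) - y
        losses.append(np.sum(residual ** 2) / N)
        G = (2 / N) * x.T @ residual
        grad_a, grad_b = G @ b.T, a.T @ G
        a, b = a - lr * grad_a, b - lr * grad_b
    return a, b, losses

rng = np.random.default_rng(0)
N, d_in, d_out, r = 64, 8, 4, 2
x = rng.normal(size=(N, d_in))
w_base = rng.normal(size=(d_in, d_out))
# Targets come from the base weights plus a true rank-r shift.
delta = rng.normal(size=(d_in, r)) @ rng.normal(size=(r, d_out))
y = x @ (w_base + 0.5 * delta)

a0 = 0.1 * rng.normal(size=(d_in, r))
b0 = np.zeros((r, d_out))
a, b, losses = train_lora_with_history(x, y, w_base, a0, b0, lr=0.01, n_steps=500)
print(losses[0] > losses[-1])
```

Because the target shift has rank r by construction, the adapter can in principle represent it exactly; with other data or a larger learning rate the loss curve may behave differently.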

Inputs

  • x: shape (N, d_in), input features.
  • y: shape (N, d_out), regression targets.
  • w_base: shape (d_in, d_out), frozen base weights.
  • a0: shape (d_in, r), initial LoRA-A (trainable).
  • b0: shape (r, d_out), initial LoRA-B (trainable).
  • lr: float, learning rate.
  • n_steps: int, number of SGD steps.

Output

Returns an array of shape (d_in*r + r*d_out,): the concatenation of final_a and final_b, both flattened.
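
If you need the two matrices back, the flat vector can be split again. The sketch below assumes row-major flattening (NumPy's default for `flatten`), which the problem does not state explicitly; `unpack_lora` is a hypothetical helper.

```python
import numpy as np

def unpack_lora(flat, d_in, d_out, r):
    """Split the flat output vector back into A (d_in, r) and B (r, d_out)."""
    flat = np.asarray(flat)
    assert flat.shape == (d_in * r + r * d_out,)
    a = flat[: d_in * r].reshape(d_in, r)
    b = flat[d_in * r :].reshape(r, d_out)
    return a, b

a = np.arange(6.0).reshape(3, 2)       # d_in=3, r=2
b = np.arange(8.0).reshape(2, 4)       # r=2, d_out=4
flat = np.concatenate([a.flatten(), b.flatten()])
a2, b2 = unpack_lora(flat, d_in=3, d_out=4, r=2)
print(flat.shape)                      # (14,)
```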

Edge cases

  • n_steps=0: loop never runs; output is concat(a0.flatten(), b0.flatten()).
  • lr=0: a and b never change (gradients are computed but multiplied by zero).
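
Both edge cases are easy to check directly. The sketch below uses a compact version of the training loop (the name `train_lora` and the random test data are assumptions) and also confirms that w_base is left untouched.

```python
import numpy as np

def train_lora(x, y, w_base, a0, b0, lr, n_steps):
    """Compact version of the training loop (function name is illustrative)."""
    a, b = a0.copy(), b0.copy()
    N = x.shape[0]
    for _ in range(n_steps):
        G = (2 / N) * x.T @ (x @ (w_base + a @ b) - y)
        # The right-hand side is evaluated before assignment, so both
        # updates use the pre-step a and b.
        a, b = a - lr * (G @ b.T), b - lr * (a.T @ G)
    return np.concatenate([a.flatten(), b.flatten()])

rng = np.random.default_rng(0)
x, y = rng.normal(size=(10, 4)), rng.normal(size=(10, 3))
w_base = rng.normal(size=(4, 3))
a0, b0 = rng.normal(size=(4, 2)), rng.normal(size=(2, 3))
start = np.concatenate([a0.flatten(), b0.flatten()])
w_before = w_base.copy()

no_steps = train_lora(x, y, w_base, a0, b0, lr=0.1, n_steps=0)
zero_lr = train_lora(x, y, w_base, a0, b0, lr=0.0, n_steps=10)

print(np.array_equal(no_steps, start))    # True
print(np.array_equal(zero_lr, start))     # True
print(np.array_equal(w_base, w_before))   # True: the base weights stay frozen
```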

Hints

fine-tuning lora training
