Train LoRA Adapter End-to-End
Fine-tune a LoRA adapter on a linear regression task, implemented
from scratch. No peft, no transformers; you own the forward pass,
the gradient, and the SGD update.
What is LoRA?
LoRA (Hu et al. 2021, "LoRA: Low-Rank Adaptation of Large Language Models") decomposes a weight update into two small matrices:
$$W = W_{\text{base}} + A B$$
where A is (d_in, r) and B is (r, d_out) with rank r ≪ min(d_in, d_out). This reduces the number of trainable parameters from
d_in × d_out to r(d_in + d_out).
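To make the saving concrete, here is a small sketch that evaluates both counts. The helper name `lora_param_counts` and the layer sizes (4096 × 4096, r = 8) are illustrative assumptions, not values from this exercise.

```python
def lora_param_counts(d_in: int, d_out: int, r: int) -> tuple[int, int]:
    """Return (full, lora) trainable-parameter counts for one weight matrix."""
    full = d_in * d_out        # training W directly
    lora = r * (d_in + d_out)  # training A (d_in, r) and B (r, d_out)
    return full, lora


# Hypothetical layer size, chosen only for illustration.
full, lora = lora_param_counts(d_in=4096, d_out=4096, r=8)
print(full, lora, full // lora)  # 16777216 65536 256
```

With these assumed sizes the adapter trains 256 times fewer parameters than the full matrix.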
Why freeze the base weights?
w_base is kept frozen: no gradient is computed or applied to it.
Only A and B are updated. This is the whole point of LoRA: you
adapt a pre-trained model cheaply by training only the low-rank
adapter.
Forward pass
effective_w = w_base + a @ b # (d_in, d_out)
y_hat = x @ effective_w # (N, d_out)
loss = mean((y_hat - y)**2) # MSE
Gradient derivation
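MSE gradient w.r.t. effective_w, with the loss averaged over the N samples (for d_out > 1, averaging over every element instead would add a factor of 1/d_out):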
$$G = \frac{2}{N} X^\top (\hat{y} - y) \quad \text{shape } (d_{\text{in}}, d_{\text{out}})$$
Because effective_w = w_base + A B, the chain rule gives:
$$\nabla_A = G B^\top, \quad \nabla_B = A^\top G$$
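One way to sanity-check these formulas is to compare them against central finite differences. The sketch below assumes NumPy and takes the loss as the squared error summed over output dimensions and averaged over N, which is the form whose gradient is exactly G as written; the helper names (`loss_fn`, `analytic_grads`, `numeric_grad`) and the random test data are illustrative.

```python
import numpy as np


def loss_fn(x, y, w_base, a, b):
    # Squared error summed over outputs, averaged over the N samples,
    # so that dL/d(effective_w) = (2/N) X^T (y_hat - y).
    y_hat = x @ (w_base + a @ b)
    return np.sum((y_hat - y) ** 2) / x.shape[0]


def analytic_grads(x, y, w_base, a, b):
    y_hat = x @ (w_base + a @ b)
    g = (2.0 / x.shape[0]) * x.T @ (y_hat - y)  # (d_in, d_out)
    return g @ b.T, a.T @ g                     # (d_in, r), (r, d_out)


def numeric_grad(f, p, eps=1e-6):
    # Central differences; perturbs p in place and restores each entry.
    grad = np.zeros_like(p)
    for idx in np.ndindex(p.shape):
        orig = p[idx]
        p[idx] = orig + eps
        plus = f()
        p[idx] = orig - eps
        minus = f()
        p[idx] = orig
        grad[idx] = (plus - minus) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
N, d_in, d_out, r = 5, 4, 3, 2
x = rng.normal(size=(N, d_in))
y = rng.normal(size=(N, d_out))
w_base = rng.normal(size=(d_in, d_out))
a = rng.normal(size=(d_in, r))
b = rng.normal(size=(r, d_out))

grad_a, grad_b = analytic_grads(x, y, w_base, a, b)
num_a = numeric_grad(lambda: loss_fn(x, y, w_base, a, b), a)
num_b = numeric_grad(lambda: loss_fn(x, y, w_base, a, b), b)
max_err_a = np.max(np.abs(grad_a - num_a))
max_err_b = np.max(np.abs(grad_b - num_b))
print(max_err_a, max_err_b)  # both should be tiny (floating-point noise)
```

Note that w_base never receives a gradient here: only `a` and `b` are perturbed.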
SGD update
a -= lr * grad_a
b -= lr * grad_b
Algorithm
a, b = a0, b0
for _ in range(n_steps):
    effective_w = w_base + a @ b
    y_hat = x @ effective_w
    G = (2/N) * x.T @ (y_hat - y)
    grad_a = G @ b.T
    grad_b = a.T @ G
    a = a - lr * grad_a
    b = b - lr * grad_b
return concat(a.flatten(), b.flatten())
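The pseudocode above can be turned into a runnable function along these lines. This is a minimal sketch assuming NumPy arrays as inputs; the function name `train_lora` and the random demo data are assumptions, not part of the problem statement.

```python
import numpy as np


def train_lora(x, y, w_base, a0, b0, lr, n_steps):
    # Copy the initial adapters so the caller's arrays are not modified.
    a = np.array(a0, dtype=float)
    b = np.array(b0, dtype=float)
    n = x.shape[0]
    for _ in range(n_steps):
        effective_w = w_base + a @ b           # (d_in, d_out); w_base stays frozen
        y_hat = x @ effective_w                # (N, d_out)
        g = (2.0 / n) * x.T @ (y_hat - y)      # gradient w.r.t. effective_w
        grad_a = g @ b.T                       # (d_in, r)
        grad_b = a.T @ g                       # (r, d_out)
        # Both gradients use the pre-update a and b, as in the pseudocode.
        a = a - lr * grad_a
        b = b - lr * grad_b
    return np.concatenate([a.ravel(), b.ravel()])


# Illustrative random problem (sizes and seed are arbitrary).
rng = np.random.default_rng(0)
N, d_in, d_out, r = 8, 4, 3, 2
x = rng.normal(size=(N, d_in))
y = rng.normal(size=(N, d_out))
w_base = rng.normal(size=(d_in, d_out))
a0 = 0.1 * rng.normal(size=(d_in, r))
b0 = 0.1 * rng.normal(size=(r, d_out))

start = np.concatenate([a0.ravel(), b0.ravel()])
out = train_lora(x, y, w_base, a0, b0, lr=0.01, n_steps=100)
print(out.shape)  # (14,) == (d_in*r + r*d_out,)

# With n_steps=0 or lr=0 the adapters come back unchanged.
unchanged_steps = train_lora(x, y, w_base, a0, b0, lr=0.01, n_steps=0)
unchanged_lr = train_lora(x, y, w_base, a0, b0, lr=0.0, n_steps=5)
print(np.array_equal(unchanged_steps, start), np.array_equal(unchanged_lr, start))
```

Computing both gradients before applying either update matters: updating `a` first and then using the new `a` in `grad_b` would be a different algorithm from the one described.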
Inputs
- x: shape (N, d_in), input features.
- y: shape (N, d_out), regression targets.
- w_base: shape (d_in, d_out), frozen base weights.
- a0: shape (d_in, r), initial LoRA-A (trainable).
- b0: shape (r, d_out), initial LoRA-B (trainable).
- lr: float, learning rate.
- n_steps: int, number of SGD steps.
Output
Returns an array of shape (d_in*r + r*d_out,): the concatenation of
final_a and final_b, both flattened.
Edge cases
- n_steps=0: the loop never runs; the output is concat(a0.flatten(), b0.flatten()).
- lr=0: a and b never change (gradients are computed but multiplied by zero).