Train ViT Classifier End-to-End
Implement one full ViT training step: forward pass, softmax cross-entropy loss, manual backpropagation through every layer, and an SGD weight update.
This problem composes everything from Tasks 11-13:
- Patch embedding (Task 11): image → patch sequence + [CLS] + pos embed.
- Pre-LN ViT encoder blocks (Task 12): N transformer blocks.
- Classifier head (Task 13): [CLS] token → logits.
- Softmax CE loss: loss = -mean(log p[y]).
- Manual backward pass: chain rule through every layer.
- SGD update: w ← w - lr * dw for every parameter.
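The SGD rule in the last bullet can be sketched in a few lines of NumPy. This is only an illustration: the `sgd_step` helper and the dictionary layout of parameters and gradients are assumptions made for the example, not part of the task's required interface.

```python
import numpy as np

def sgd_step(params, grads, lr):
    """Plain SGD: return w - lr * dw for every named parameter.

    `params` and `grads` are hypothetical dicts keyed by parameter name.
    New arrays are returned; the inputs are left untouched.
    """
    return {name: w - lr * grads[name] for name, w in params.items()}

# Toy parameters and gradients, chosen only to make the arithmetic visible.
params = {"w_head": np.ones((2, 3)), "cls_token": np.zeros((1, 1, 2))}
grads = {"w_head": np.full((2, 3), 0.5), "cls_token": np.ones((1, 1, 2))}
updated = sgd_step(params, grads, lr=0.1)
```

There is no momentum or weight decay here, which matches the update rule stated in the list.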
Forward (identical to vit_classify)
patches = extract_patches(x) # (N, P², C·P²)
patches = patches @ w_proj # (N, P², d)
seq = cat([cls_token, patches], dim=1) + pos_embed # (N, T, d), T = P²+1
for each block:
    norm1 = layer_norm(seq)
    Q,K,V = norm1 @ w_q, norm1 @ w_k, norm1 @ w_v
    attn = softmax(Q @ K.T / sqrt(d_head))
    seq = seq + (attn @ V) @ w_o # residual 1
    norm2 = layer_norm(seq)
    seq = seq + gelu(norm2 @ w_mlp1) @ w_mlp2 # residual 2
cls_repr = seq[:, 0, :]
logits = cls_repr @ w_head
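As a rough illustration of the per-block part of the forward pass above, the following NumPy sketch runs one Pre-LN block and caches the intermediates that the backward pass will need. Several details are assumptions, since the problem defers to vit_classify: a single attention head (so d_head = d), layer norm without learnable scale or shift, eps = 1e-5, and the tanh approximation of GELU. The function name `block_forward` and the cache keys are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalise over the feature axis; no learnable gain/bias (assumption).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of GELU (assumption; vit_classify may use erf).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def softmax(x):
    # Subtract the row max for numerical stability.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def block_forward(seq, w):
    """seq: (N, T, d); w: (6, d, d) ordered [w_q, w_k, w_v, w_o, w_mlp1, w_mlp2]."""
    w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 = w
    d = seq.shape[-1]
    norm1 = layer_norm(seq)
    q, k, v = norm1 @ w_q, norm1 @ w_k, norm1 @ w_v
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d))  # (N, T, T)
    seq_post_attn = seq + (attn @ v) @ w_o                  # residual 1
    norm2 = layer_norm(seq_post_attn)
    pre_gelu = norm2 @ w_mlp1
    gelu_out = gelu(pre_gelu)
    out = seq_post_attn + gelu_out @ w_mlp2                 # residual 2
    cache = dict(seq=seq, norm1=norm1, q=q, k=k, v=v, attn=attn,
                 seq_post_attn=seq_post_attn, norm2=norm2,
                 pre_gelu=pre_gelu, gelu_out=gelu_out)
    return out, cache

rng = np.random.default_rng(0)
seq = rng.normal(size=(2, 5, 4))        # N=2, T=5, d=4
w = rng.normal(size=(6, 4, 4)) * 0.1
out, cache = block_forward(seq, w)
```

Keeping one cache per block, in a list, makes it straightforward to walk the blocks in reverse later.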
Backward (manual chain rule)
Start at the loss:
probs = softmax(logits) # (N, num_classes)
dlogits = (probs - one_hot(y)) / N # (N, num_classes)
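The gradient formula above can be checked numerically. The sketch below computes the loss and `dlogits` for integer labels and compares one entry against a finite difference; the helper name `softmax_ce` and the toy data are made up for the example.

```python
import numpy as np

def softmax_ce(logits, y):
    """Return (loss, dlogits) for logits (N, num_classes) and int labels y (N,)."""
    n = logits.shape[0]
    z = logits - logits.max(axis=1, keepdims=True)   # stabilise exp
    e = np.exp(z)
    probs = e / e.sum(axis=1, keepdims=True)
    loss = -np.log(probs[np.arange(n), y]).mean()
    dlogits = probs.copy()
    dlogits[np.arange(n), y] -= 1.0                  # probs - one_hot(y)
    return loss, dlogits / n                         # 1/N comes from the mean

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 3))
y = np.array([0, 2, 1, 2])
loss, dlogits = softmax_ce(logits, y)

# Forward finite difference on a single logit as a sanity check.
eps = 1e-6
bumped = logits.copy()
bumped[1, 2] += eps
numeric = (softmax_ce(bumped, y)[0] - loss) / eps
```

The division by N is easy to forget; it appears because the loss is a mean over the batch rather than a sum.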
Classifier head:
dw_head = cls_repr.T @ dlogits # (d, num_classes)
dcls_repr = dlogits @ w_head.T # (N, d)
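Because only the [CLS] position feeds the head, the gradient that enters the last block is zero at every other token. A small sketch of the head backward and of that scatter step, using random stand-in tensors (all shapes and values here are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, d, num_classes = 2, 5, 4, 3
seq = rng.normal(size=(N, T, d))            # output of the last block
w_head = rng.normal(size=(d, num_classes))
dlogits = rng.normal(size=(N, num_classes)) # stand-in for (probs - one_hot) / N

cls_repr = seq[:, 0, :]                     # (N, d)
dw_head = cls_repr.T @ dlogits              # (d, num_classes)
dcls_repr = dlogits @ w_head.T              # (N, d)

# Scatter back into the full sequence gradient: only token 0 is non-zero.
dseq = np.zeros_like(seq)
dseq[:, 0, :] = dcls_repr
```

Other tokens still receive gradient further down, once `dseq` has passed back through attention in the last block.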
Propagate dcls_repr back through blocks in REVERSE order. At each block, dseq flows around each residual and through each sub-layer:
FFN sub-layer (reverse):
# seq_post_attn is the seq value before the FFN residual add
d_ffn_out = dseq # gradient through the residual add
# backward through w_mlp2
dw_mlp2 = gelu_out.T @ d_ffn_out
d_gelu_out = d_ffn_out @ w_mlp2.T
... etc.
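One way the FFN backward could be written out in full is sketched below. It assumes the tanh approximation of GELU (the source only says GELU is the same as in vit_classify) and d_ff = d_model; the names `gelu_grad`, `ffn_backward`, `pre_gelu`, and `d_gelu_out` are illustrative. Batch and token axes are flattened before forming weight gradients, so each weight gradient sums over all N·T positions.

```python
import numpy as np

SQRT_2_OVER_PI = np.sqrt(2.0 / np.pi)

def gelu(x):
    # tanh approximation (assumption).
    return 0.5 * x * (1.0 + np.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x**3)))

def gelu_grad(x):
    # Derivative of the tanh approximation above.
    t = np.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x**3))
    du = SQRT_2_OVER_PI * (1.0 + 3.0 * 0.044715 * x**2)
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t**2) * du

def ffn_backward(dseq, norm2, pre_gelu, gelu_out, w_mlp1, w_mlp2):
    """Backward through gelu(norm2 @ w_mlp1) @ w_mlp2 (the non-residual path)."""
    d = dseq.shape[-1]
    flat = lambda a: a.reshape(-1, d)               # (N*T, d)
    d_ffn_out = dseq                                # gradient through the residual add
    dw_mlp2 = flat(gelu_out).T @ flat(d_ffn_out)
    d_gelu_out = d_ffn_out @ w_mlp2.T
    d_pre_gelu = d_gelu_out * gelu_grad(pre_gelu)
    dw_mlp1 = flat(norm2).T @ flat(d_pre_gelu)
    dnorm2 = d_pre_gelu @ w_mlp1.T
    return dnorm2, dw_mlp1, dw_mlp2

rng = np.random.default_rng(0)
norm2 = rng.normal(size=(2, 3, 4))
w_mlp1 = rng.normal(size=(4, 4))
w_mlp2 = rng.normal(size=(4, 4))
dseq = rng.normal(size=(2, 3, 4))
pre_gelu = norm2 @ w_mlp1
gelu_out = gelu(pre_gelu)
dnorm2, dw_mlp1, dw_mlp2 = ffn_backward(dseq, norm2, pre_gelu, gelu_out, w_mlp1, w_mlp2)
```

The returned `dnorm2` still has to pass through the layer norm backward and be added to `dseq`, since the residual path carries `dseq` through unchanged.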
Attention sub-layer (reverse): backward through w_o, softmax attn, Q/K/V projections, then layer_norm.
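The attention backward described above might look like the following NumPy sketch. It assumes a single head (d_head = d), which may differ from vit_classify, and stops at `dnorm1`; the layer norm backward is a separate step. The names `attn_forward`, `attn_backward`, and `ctx` are hypothetical.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attn_forward(norm1, w_q, w_k, w_v, w_o):
    d = norm1.shape[-1]
    q, k, v = norm1 @ w_q, norm1 @ w_k, norm1 @ w_v
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d))
    ctx = attn @ v
    return ctx @ w_o, (norm1, q, k, v, attn, ctx)

def attn_backward(dout, cache, w_q, w_k, w_v, w_o):
    """dout: gradient w.r.t. (attn @ V) @ w_o, shape (N, T, d)."""
    norm1, q, k, v, attn, ctx = cache
    d = norm1.shape[-1]
    flat = lambda a: a.reshape(-1, d)
    dw_o = flat(ctx).T @ flat(dout)
    dctx = dout @ w_o.T
    dattn = dctx @ v.transpose(0, 2, 1)            # (N, T, T)
    dv = attn.transpose(0, 2, 1) @ dctx
    # Softmax backward, applied independently to each row of attn.
    dscores = attn * (dattn - (dattn * attn).sum(axis=-1, keepdims=True))
    dscores = dscores / np.sqrt(d)                 # undo the 1/sqrt(d) scaling
    dq = dscores @ k
    dk = dscores.transpose(0, 2, 1) @ q
    grads = {"w_q": flat(norm1).T @ flat(dq),
             "w_k": flat(norm1).T @ flat(dk),
             "w_v": flat(norm1).T @ flat(dv),
             "w_o": dw_o}
    # norm1 feeds all three projections, so their gradients add up.
    dnorm1 = dq @ w_q.T + dk @ w_k.T + dv @ w_v.T
    return dnorm1, grads

rng = np.random.default_rng(0)
norm1 = rng.normal(size=(2, 3, 4))
w_q, w_k, w_v, w_o = (rng.normal(size=(4, 4)) * 0.5 for _ in range(4))
dout = rng.normal(size=(2, 3, 4))
_, cache = attn_forward(norm1, w_q, w_k, w_v, w_o)
dnorm1, grads = attn_backward(dout, cache, w_q, w_k, w_v, w_o)
```

With multiple heads the same steps apply per head after splitting the feature axis, which this sketch does not cover.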
Patch embed / cls / pos (after all blocks):
dw_proj = patches_raw.T @ dpatches
dcls_token = dseq[:, 0, :].sum(dim=0)
dpos_embed = dseq.sum(dim=0)
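A NumPy version of these three gradients is sketched below. The parameter shapes are assumptions, since the problem does not state them here: `cls_token` is taken to be (1, 1, d) and `pos_embed` (1, T, d), both broadcast over the batch, and `dpatches` is taken to be the non-[CLS] slice of `dseq`. The helper name `embed_backward` is made up.

```python
import numpy as np

def embed_backward(dseq, patches_raw):
    """dseq: (N, T, d) gradient at the block input; patches_raw: (N, T-1, patch_dim)."""
    d = dseq.shape[-1]
    patch_dim = patches_raw.shape[-1]
    dpatches = dseq[:, 1:, :]                        # skip the [CLS] position
    dw_proj = patches_raw.reshape(-1, patch_dim).T @ dpatches.reshape(-1, d)
    # Parameters shared across the batch accumulate gradient over the batch axis.
    dcls_token = dseq[:, 0, :].sum(axis=0).reshape(1, 1, d)
    dpos_embed = dseq.sum(axis=0, keepdims=True)     # (1, T, d)
    return dw_proj, dcls_token, dpos_embed

rng = np.random.default_rng(0)
N, num_patches, patch_dim, d = 2, 4, 3, 5
patches_raw = rng.normal(size=(N, num_patches, patch_dim))
dseq = rng.normal(size=(N, num_patches + 1, d))
dw_proj, dcls_token, dpos_embed = embed_backward(dseq, patches_raw)
```

If the parameters are stored without the leading singleton axes, the reshape and `keepdims` can simply be dropped.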
Weight packing
blocks_weights has shape (num_blocks, 6, d_model, d_model).
Index order: [w_q, w_k, w_v, w_o, w_mlp1, w_mlp2].
Simplification: d_ff = d_model, so all six are (d_model, d_model).
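To make the packing concrete, here is a small sketch that pulls the six matrices of one block out of `blocks_weights` by name. The `unpack_block` helper and the `NAMES` tuple are illustrative, not part of the task's interface; only the index order comes from the statement above.

```python
import numpy as np

NAMES = ("w_q", "w_k", "w_v", "w_o", "w_mlp1", "w_mlp2")  # index order from the spec

def unpack_block(blocks_weights, b):
    """Return a name -> (d_model, d_model) view for block b."""
    return dict(zip(NAMES, blocks_weights[b]))

num_blocks, d_model = 2, 4
blocks_weights = np.arange(num_blocks * 6 * d_model * d_model, dtype=float)
blocks_weights = blocks_weights.reshape(num_blocks, 6, d_model, d_model)

block1 = unpack_block(blocks_weights, 1)

# Gradients can be accumulated in an array with the same layout.
dblocks_weights = np.zeros_like(blocks_weights)
```

Storing gradients in the same (num_blocks, 6, d_model, d_model) layout lets the SGD update be applied to the whole array at once.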
Output
Returns a single flat tensor of all updated weights concatenated in order:
[w_proj_flat, cls_token_flat, pos_embed_flat, blocks_weights_flat, w_head_flat]
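The concatenation order above could be implemented as follows. Row-major (C-order) flattening is an assumption, as are the toy shapes and the `pack_weights` name; the task's checker may expect a specific layout.

```python
import numpy as np

def pack_weights(w_proj, cls_token, pos_embed, blocks_weights, w_head):
    """Flatten each tensor in row-major order and concatenate in the stated order."""
    parts = (w_proj, cls_token, pos_embed, blocks_weights, w_head)
    return np.concatenate([p.reshape(-1) for p in parts])

# Toy tensors filled with distinct constants so the segments are easy to spot.
d, T, patch_dim, num_blocks, num_classes = 4, 5, 3, 2, 3
flat = pack_weights(
    np.full((patch_dim, d), 1.0),          # w_proj:         12 values
    np.full((1, 1, d), 2.0),               # cls_token:       4 values
    np.full((1, T, d), 3.0),               # pos_embed:      20 values
    np.full((num_blocks, 6, d, d), 4.0),   # blocks_weights: 192 values
    np.full((d, num_classes), 5.0),        # w_head:         12 values
)
```

In a full solution the arguments would be the weights after the SGD update, not the originals.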
Implementation constraints
- No loss.backward() and no jax.grad: implement the backward pass by hand using the chain rule.
- Manual layer norm, manual GELU, manual softmax, same as vit_classify.
- Cache all intermediate activations during the forward pass; the backward pass needs them.
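As an example of the cache-then-backpropagate pattern these constraints call for, here is a sketch of a manual layer norm with its backward pass. It assumes no learnable scale or shift (none appears in the parameter list above) and eps = 1e-5; the function names are hypothetical.

```python
import numpy as np

def layer_norm_forward(x, eps=1e-5):
    """Normalise over the last axis and cache what the backward pass needs."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    inv_std = 1.0 / np.sqrt(var + eps)
    x_hat = (x - mu) * inv_std
    return x_hat, (x_hat, inv_std)

def layer_norm_backward(dy, cache):
    """Gradient w.r.t. the input x, given the gradient dy w.r.t. x_hat."""
    x_hat, inv_std = cache
    # Remove the components of dy along the mean and along x_hat,
    # the two directions the normalisation is invariant to.
    dy_mean = dy.mean(axis=-1, keepdims=True)
    proj = (dy * x_hat).mean(axis=-1, keepdims=True)
    return inv_std * (dy - dy_mean - x_hat * proj)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4))
dy = rng.normal(size=(2, 3, 4))
y, cache = layer_norm_forward(x)
dx = layer_norm_backward(dy, cache)
```

Caching `x_hat` and `inv_std` avoids recomputing the mean and variance during the backward pass.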
References
- Dosovitskiy et al., "An Image Is Worth 16x16 Words", ICLR 2021.