Tags: hard, end_to_end

Train ViT Classifier End-to-End

Implement one full ViT training step: forward pass, softmax cross-entropy loss, manual backpropagation through every layer, and an SGD weight update.

This problem composes everything from Tasks 11–13:

  1. Patch embedding (Task 11): image → patch sequence + [CLS] + pos embed.
  2. Pre-LN ViT encoder blocks (Task 12): num_blocks transformer blocks.
  3. Classifier head (Task 13): [CLS] token → logits.
  4. Softmax CE loss: loss = -mean(log p[y]).
  5. Manual backward pass: chain rule through every layer.
  6. SGD update: w ← w - lr * dw for every parameter.

Forward (identical to vit_classify)

patches_raw = extract_patches(x)                     # (N, P², C·P²)
patches = patches_raw @ w_proj                       # (N, P², d)
seq = cat([cls_token, patches], dim=1) + pos_embed   # (N, T, d),  T = P²+1
for each block:
    norm1  = layer_norm(seq)
    Q,K,V  = norm1 @ w_q, norm1 @ w_k, norm1 @ w_v
    attn   = softmax(Q @ K.transpose(-2, -1) / sqrt(d_head))   # per head
    seq    = seq + (attn @ V) @ w_o          # residual 1
    norm2  = layer_norm(seq)
    seq    = seq + gelu(norm2 @ w_mlp1) @ w_mlp2  # residual 2
cls_repr = seq[:, 0, :]
logits   = cls_repr @ w_head
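As a concrete reference, here is a minimal NumPy sketch of this forward pass that also records the intermediate activations the backward pass reuses. NumPy is used only for illustration, and several details are assumptions the spec does not fix: x has shape (N, C, H, W) with H and W divisible by the patch size p, each patch is flattened in (C, p, p) order, layer norm has no learnable scale/shift and uses eps = 1e-5, GELU is the tanh approximation, and attention is split across num_heads heads. The names vit_forward, split_heads and the cache keys are hypothetical.

```python
import numpy as np

EPS = 1e-5  # assumed layer-norm epsilon

def layer_norm(x):
    # Normalise over the feature axis; no learnable scale/shift (none are packed)
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + EPS)

def gelu(x):
    # tanh approximation of GELU (assumed; vit_classify may use the exact erf form)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))  # shift for numerical stability
    return e / e.sum(axis=-1, keepdims=True)

def extract_patches(x, p):
    # (N, C, H, W) -> (N, num_patches, C*p*p), patches in row-major grid order
    N, C, H, W = x.shape
    x = x.reshape(N, C, H // p, p, W // p, p).transpose(0, 2, 4, 1, 3, 5)
    return x.reshape(N, (H // p) * (W // p), C * p * p)

def split_heads(m, num_heads):
    # (N, T, d) -> (N, heads, T, d_head)
    N, T, d = m.shape
    return m.reshape(N, T, num_heads, d // num_heads).transpose(0, 2, 1, 3)

def vit_forward(x, w_proj, cls_token, pos_embed, blocks_weights, w_head, p, num_heads):
    N, d = x.shape[0], w_proj.shape[1]
    cache = {"patches_raw": extract_patches(x, p), "blocks": []}
    patches = cache["patches_raw"] @ w_proj                      # (N, num_patches, d)
    cls = np.broadcast_to(cls_token.reshape(1, 1, d), (N, 1, d))
    seq = np.concatenate([cls, patches], axis=1) + pos_embed.reshape(1, -1, d)
    for w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 in blocks_weights:
        c = {"seq_in": seq, "n1": layer_norm(seq)}
        q, k, v = (split_heads(c["n1"] @ w, num_heads) for w in (w_q, w_k, w_v))
        c["attn"] = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1]))
        c["ctx"] = (c["attn"] @ v).transpose(0, 2, 1, 3).reshape(seq.shape)
        seq = seq + c["ctx"] @ w_o                                # residual 1
        c.update(q=q, k=k, v=v, seq_mid=seq, n2=layer_norm(seq))
        c["h"] = c["n2"] @ w_mlp1
        c["g"] = gelu(c["h"])
        seq = seq + c["g"] @ w_mlp2                               # residual 2
        cache["blocks"].append(c)
    cache["cls_repr"] = seq[:, 0, :]
    return cache["cls_repr"] @ w_head, cache

# Tiny usage example with random weights
rng = np.random.default_rng(0)
N, C, H, p, d, num_blocks, num_classes = 2, 3, 8, 4, 8, 2, 5
T = (H // p) ** 2 + 1
logits, cache = vit_forward(
    rng.standard_normal((N, C, H, H)),
    0.1 * rng.standard_normal((C * p * p, d)),
    0.1 * rng.standard_normal(d),
    0.1 * rng.standard_normal((T, d)),
    0.1 * rng.standard_normal((num_blocks, 6, d, d)),
    0.1 * rng.standard_normal((d, num_classes)),
    p, num_heads=2,
)
```

Caching per block in a dict keeps the backward pass free of recomputation, at the cost of holding every activation in memory.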

Backward (manual chain rule)

Start at the loss:

probs   = softmax(logits)                # (N, num_classes)
dlogits = (probs - one_hot(y)) / N       # (N, num_classes)
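A small sketch of the loss and its gradient, with a central finite-difference check that the closed form (probs - one_hot(y)) / N really is d loss / d logits. NumPy is used for illustration, and the function name ce_loss_and_grad is hypothetical.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))  # stable softmax over classes
    return e / e.sum(axis=-1, keepdims=True)

def ce_loss_and_grad(logits, y):
    # loss = -mean(log p[y]);  d loss / d logits = (probs - one_hot(y)) / N
    N, num_classes = logits.shape
    probs = softmax(logits)
    loss = -np.mean(np.log(probs[np.arange(N), y]))
    dlogits = (probs - np.eye(num_classes)[y]) / N
    return loss, dlogits

# Check the closed form against central finite differences
rng = np.random.default_rng(0)
logits = rng.standard_normal((4, 3))
y = np.array([0, 2, 1, 2])
loss, dlogits = ce_loss_and_grad(logits, y)
numeric = np.zeros_like(logits)
for idx in np.ndindex(logits.shape):
    bump = np.zeros_like(logits)
    bump[idx] = 1e-6
    numeric[idx] = (ce_loss_and_grad(logits + bump, y)[0]
                    - ce_loss_and_grad(logits - bump, y)[0]) / 2e-6
```

Each row of dlogits sums to zero, because every row of probs and of one_hot(y) sums to one.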

Classifier head:

dw_head    = cls_repr.T @ dlogits        # (d, num_classes)
dcls_repr  = dlogits @ w_head.T          # (N, d)
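The head backward is two transposed matmuls. The NumPy sketch below (hypothetical name head_backward) also scatters dcls_repr into an otherwise-zero (N, T, d) gradient, since only the [CLS] position reached the head; that tensor is what enters the last block.

```python
import numpy as np

def head_backward(cls_repr, w_head, dlogits, T):
    # logits = cls_repr @ w_head, so both gradients are transposed matmuls
    dw_head = cls_repr.T @ dlogits          # (d, num_classes)
    dcls_repr = dlogits @ w_head.T          # (N, d)
    # Only position 0 ([CLS]) fed the head: scatter into an otherwise-zero dseq
    dseq = np.zeros((cls_repr.shape[0], T, cls_repr.shape[1]))
    dseq[:, 0, :] = dcls_repr
    return dw_head, dseq

rng = np.random.default_rng(0)
cls_repr = rng.standard_normal((2, 4))      # N = 2, d = 4
w_head = rng.standard_normal((4, 3))        # num_classes = 3
dlogits = rng.standard_normal((2, 3))
dw_head, dseq = head_backward(cls_repr, w_head, dlogits, T=5)
```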

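Propagate dcls_repr back through the blocks in REVERSE order. Seed dseq as a zero tensor of shape (N, T, d) with dseq[:, 0, :] = dcls_repr, since only the [CLS] position feeds the head. At each block, dseq flows both around each residual (unchanged) and through each sub-layer, and the two contributions are summed: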

FFN sub-layer (reverse):

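# seq_post_attn = seq before the FFN residual add (input to norm2); rows flattened to (N·T, d)
d_ffn_out = dseq                                   # gradient through the residual add
dw_mlp2   = gelu_out.T @ d_ffn_out                 # gelu_out = gelu(norm2 @ w_mlp1)
dgelu_out = d_ffn_out @ w_mlp2.T
dmlp1_out = dgelu_out * gelu_grad(norm2 @ w_mlp1)  # elementwise GELU derivative
dw_mlp1   = norm2.T @ dmlp1_out
dnorm2    = dmlp1_out @ w_mlp1.T
dseq      = dseq + layer_norm_backward(dnorm2, seq_post_attn)  # residual + LN path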
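A runnable NumPy sketch of the FFN sub-layer backward, including the layer-norm and GELU derivatives it needs, checked against finite differences of L = sum(G * out) for a random upstream gradient G. It assumes layer norm without learnable parameters (eps = 1e-5) and the tanh form of GELU; if vit_classify uses the exact erf GELU, gelu_grad must change accordingly. All function names are hypothetical.

```python
import numpy as np

EPS = 1e-5                      # assumed layer-norm epsilon
C0 = np.sqrt(2.0 / np.pi)

def layer_norm(x):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + EPS)

def layer_norm_backward(dxhat, x):
    # No affine params: dx = (dxhat - mean(dxhat) - xhat * mean(dxhat * xhat)) / sigma
    xhat = layer_norm(x)
    sigma = np.sqrt(x.var(-1, keepdims=True) + EPS)
    return (dxhat - dxhat.mean(-1, keepdims=True)
            - xhat * (dxhat * xhat).mean(-1, keepdims=True)) / sigma

def gelu(x):
    # tanh approximation (assumed)
    return 0.5 * x * (1.0 + np.tanh(C0 * (x + 0.044715 * x**3)))

def gelu_grad(x):
    # Derivative of the tanh-approximated GELU above
    t = np.tanh(C0 * (x + 0.044715 * x**3))
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t**2) * C0 * (1.0 + 3 * 0.044715 * x**2)

def flat(m):
    return m.reshape(-1, m.shape[-1])           # (N, T, k) -> (N*T, k)

def ffn_forward(s, w_mlp1, w_mlp2):
    n2 = layer_norm(s)
    h = n2 @ w_mlp1
    g = gelu(h)
    return s + g @ w_mlp2, dict(s=s, n2=n2, h=h, g=g)

def ffn_backward(dseq, c, w_mlp1, w_mlp2):
    d_ffn_out = dseq                            # gradient through the residual add
    dw_mlp2 = flat(c["g"]).T @ flat(d_ffn_out)  # sums over N and T
    dgelu_out = d_ffn_out @ w_mlp2.T
    dh = dgelu_out * gelu_grad(c["h"])          # elementwise GELU derivative
    dw_mlp1 = flat(c["n2"]).T @ flat(dh)
    dnorm2 = dh @ w_mlp1.T
    return dseq + layer_norm_backward(dnorm2, c["s"]), dw_mlp1, dw_mlp2

rng = np.random.default_rng(0)
s = rng.standard_normal((2, 3, 4))
w1, w2 = (0.5 * rng.standard_normal((4, 4)) for _ in range(2))
G = rng.standard_normal((2, 3, 4))             # upstream gradient: L = sum(G * out)
out, cache = ffn_forward(s, w1, w2)
ds, dw1, dw2 = ffn_backward(G, cache, w1, w2)

def numeric_grad(a):
    # Central differences of L w.r.t. array `a`, perturbed in place and restored
    g = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + 1e-6
        lp = np.sum(G * ffn_forward(s, w1, w2)[0])
        a[idx] = old - 1e-6
        lm = np.sum(G * ffn_forward(s, w1, w2)[0])
        a[idx] = old
        g[idx] = (lp - lm) / 2e-6
    return g
```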

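Attention sub-layer (reverse): backward through w_o, the weighted sum attn @ V, the softmax, the scaled scores Q @ K.T / sqrt(d_head), and the Q/K/V projections (the three contributions to dnorm1 are summed), then through norm1's layer norm; as in the FFN, add the result to the gradient arriving around residual 1.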
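The same approach for the attention sub-layer, as a hedged NumPy sketch with multi-head attention (H heads of width d / H, an assumption based on the d_head in the forward pass). The softmax backward uses the row-wise identity dscores = attn * (dattn - sum(dattn * attn)). Layer norm is again assumed to have no learnable parameters; function names are hypothetical and the finite-difference check mirrors the FFN one.

```python
import numpy as np

EPS = 1e-5  # assumed layer-norm epsilon

def layer_norm(x):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + EPS)

def layer_norm_backward(dxhat, x):
    xhat, sigma = layer_norm(x), np.sqrt(x.var(-1, keepdims=True) + EPS)
    return (dxhat - dxhat.mean(-1, keepdims=True)
            - xhat * (dxhat * xhat).mean(-1, keepdims=True)) / sigma

def split(m, H):                 # (N, T, d) -> (N, H, T, d // H)
    N, T, d = m.shape
    return m.reshape(N, T, H, d // H).transpose(0, 2, 1, 3)

def merge(m):                    # (N, H, T, dh) -> (N, T, H * dh)
    N, H, T, dh = m.shape
    return m.transpose(0, 2, 1, 3).reshape(N, T, H * dh)

def flat(m):                     # (..., k) -> (rows, k)
    return m.reshape(-1, m.shape[-1])

def attn_forward(s, w_q, w_k, w_v, w_o, H):
    n1 = layer_norm(s)
    q, k, v = split(n1 @ w_q, H), split(n1 @ w_k, H), split(n1 @ w_v, H)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    attn = np.exp(scores - scores.max(-1, keepdims=True))
    attn /= attn.sum(-1, keepdims=True)
    ctx = merge(attn @ v)
    return s + ctx @ w_o, dict(s=s, n1=n1, q=q, k=k, v=v, attn=attn, ctx=ctx)

def attn_backward(dseq, c, w_q, w_k, w_v, w_o, H):
    dw_o = flat(c["ctx"]).T @ flat(dseq)
    dctx = split(dseq @ w_o.T, H)                 # per-head grad of attn @ V
    dattn = dctx @ c["v"].transpose(0, 1, 3, 2)
    dv = c["attn"].transpose(0, 1, 3, 2) @ dctx
    # softmax backward, row by row: a * (da - sum(da * a))
    dscores = c["attn"] * (dattn - (dattn * c["attn"]).sum(-1, keepdims=True))
    scale = 1.0 / np.sqrt(c["q"].shape[-1])
    dQ = merge(dscores @ c["k"] * scale)
    dK = merge(dscores.transpose(0, 1, 3, 2) @ c["q"] * scale)
    dV = merge(dv)
    n1 = flat(c["n1"])
    grads = (n1.T @ flat(dQ), n1.T @ flat(dK), n1.T @ flat(dV), dw_o)
    dnorm1 = dQ @ w_q.T + dK @ w_k.T + dV @ w_v.T  # three paths into norm1
    return dseq + layer_norm_backward(dnorm1, c["s"]), grads

rng = np.random.default_rng(1)
H = 2
s = rng.standard_normal((2, 3, 4))
W = [0.5 * rng.standard_normal((4, 4)) for _ in range(4)]   # w_q, w_k, w_v, w_o
G = rng.standard_normal((2, 3, 4))                          # L = sum(G * out)
out, cache = attn_forward(s, *W, H)
ds, (dw_q, dw_k, dw_v, dw_o) = attn_backward(G, cache, *W, H)

def numeric_grad(a):
    # Central differences of L w.r.t. array `a`, perturbed in place and restored
    g = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + 1e-6
        lp = np.sum(G * attn_forward(s, *W, H)[0])
        a[idx] = old - 1e-6
        lm = np.sum(G * attn_forward(s, *W, H)[0])
        a[idx] = old
        g[idx] = (lp - lm) / 2e-6
    return g
```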

Patch embed / cls / pos (after all blocks):

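dpatches   = dseq[:, 1:, :]                 # patch-token rows only, (N, P², d)
dw_proj    = patches_raw.T @ dpatches       # both flattened to (N·P², ·) rows
dcls_token = dseq[:, 0, :].sum(dim=0)       # cls_token is shared across the batch
dpos_embed = dseq.sum(dim=0)                # (T, d), shared across the batch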
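A NumPy sketch of these three gradients with a numerical check. Because the embedding stage is linear in w_proj, cls_token and pos_embed, the finite differences should match almost exactly. The patch flattening order in extract_patches, the shapes chosen for cls_token (d,) and pos_embed (T, d), and all function names are assumptions.

```python
import numpy as np

def extract_patches(x, p):
    # (N, C, H, W) -> (N, num_patches, C*p*p); flattening order is an assumption
    N, C, H, W = x.shape
    x = x.reshape(N, C, H // p, p, W // p, p).transpose(0, 2, 4, 1, 3, 5)
    return x.reshape(N, (H // p) * (W // p), C * p * p)

def embed_forward(patches_raw, w_proj, cls_token, pos_embed):
    N, d = patches_raw.shape[0], w_proj.shape[1]
    cls = np.broadcast_to(cls_token, (N, 1, d))
    return np.concatenate([cls, patches_raw @ w_proj], axis=1) + pos_embed

def embed_backward(dseq, patches_raw):
    dpatches = dseq[:, 1:, :]                             # patch-token rows only
    dw_proj = (patches_raw.reshape(-1, patches_raw.shape[-1]).T
               @ dpatches.reshape(-1, dseq.shape[-1]))    # (C*p*p, d)
    dcls_token = dseq[:, 0, :].sum(axis=0)                # one cls_token shared by all N
    dpos_embed = dseq.sum(axis=0)                         # (T, d), shared by all N
    return dw_proj, dcls_token, dpos_embed

rng = np.random.default_rng(0)
patches_raw = extract_patches(rng.standard_normal((2, 1, 4, 4)), p=2)   # (2, 4, 4)
w_proj = rng.standard_normal((4, 3))
cls_token = rng.standard_normal(3)
pos_embed = rng.standard_normal((5, 3))
G = rng.standard_normal((2, 5, 3))                        # L = sum(G * seq)
dw_proj, dcls_token, dpos_embed = embed_backward(G, patches_raw)

def numeric_grad(a):
    # Central differences of L w.r.t. array `a`, perturbed in place and restored
    g = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + 1e-6
        lp = np.sum(G * embed_forward(patches_raw, w_proj, cls_token, pos_embed))
        a[idx] = old - 1e-6
        lm = np.sum(G * embed_forward(patches_raw, w_proj, cls_token, pos_embed))
        a[idx] = old
        g[idx] = (lp - lm) / 2e-6
    return g
```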

Weight packing

blocks_weights has shape (num_blocks, 6, d_model, d_model). Index order: [w_q, w_k, w_v, w_o, w_mlp1, w_mlp2]. Simplification: d_ff = d_model, so all six are (d_model, d_model).
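Unpacking and re-packing the per-block weights is plain indexing along the second axis. A minimal NumPy sketch, with hypothetical helper names; only the slot order comes from the spec.

```python
import numpy as np

# Slot order along axis 1 of blocks_weights, as given in the spec
NAMES = ("w_q", "w_k", "w_v", "w_o", "w_mlp1", "w_mlp2")

def unpack_block(blocks_weights, b):
    # (num_blocks, 6, d_model, d_model) -> dict of six (d_model, d_model) views
    return dict(zip(NAMES, blocks_weights[b]))

def pack_block_grads(grads_per_block):
    # Inverse direction: list of per-block dicts -> (num_blocks, 6, d_model, d_model)
    return np.stack([np.stack([g[n] for n in NAMES]) for g in grads_per_block])

blocks_weights = np.arange(2 * 6 * 3 * 3, dtype=float).reshape(2, 6, 3, 3)
w = unpack_block(blocks_weights, 1)
repacked = pack_block_grads([unpack_block(blocks_weights, b) for b in range(2)])
```

Packing gradients with the same slot order lets the SGD update treat blocks_weights as a single tensor.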

Output

Returns a single flat tensor of all updated weights concatenated in order:

[w_proj_flat, cls_token_flat, pos_embed_flat, blocks_weights_flat, w_head_flat]
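A minimal NumPy sketch of the SGD update followed by this flat output layout. The dictionary keys, the learning-rate value and the example shapes are hypothetical; only the concatenation order comes from the spec.

```python
import numpy as np

ORDER = ("w_proj", "cls_token", "pos_embed", "blocks_weights", "w_head")

def sgd_step_flat(params, grads, lr):
    # w <- w - lr * dw for every parameter, then flatten in the required order
    return np.concatenate([(params[k] - lr * grads[k]).ravel() for k in ORDER])

# Illustrative shapes only (d = 8, 4 patches + [CLS], 2 blocks, 5 classes)
shapes = {"w_proj": (48, 8), "cls_token": (1, 1, 8), "pos_embed": (1, 5, 8),
          "blocks_weights": (2, 6, 8, 8), "w_head": (8, 5)}
params = {k: np.full(shapes[k], float(i)) for i, k in enumerate(ORDER)}
grads = {k: np.ones(s) for k, s in shapes.items()}
flat_out = sgd_step_flat(params, grads, lr=0.5)
```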

Implementation constraints

  • No loss.backward() and no jax.grad: implement the backward pass by hand using the chain rule.
  • Manual layer norm, manual GELU, manual softmax, same as vit_classify.
  • Cache all intermediate activations during the forward pass; the backward pass needs them.

References

  • Dosovitskiy et al., "An Image Is Worth 16x16 Words", ICLR 2021.

