
ViT Classification Forward

Implement the full Vision Transformer (ViT) forward pass for image classification: patch embedding → num_blocks transformer blocks → classifier head on the [CLS] token.

This problem composes the two preceding building blocks end-to-end and adds a classifier head:

  1. Patch embedding with [CLS] token (Task 11): converts an image batch (N, C, H, W) into a sequence (N, num_patches + 1, d_model) by extracting non-overlapping P×P patches, projecting them with w_proj, prepending a [CLS] token, and adding position embeddings.

  2. Pre-LN ViT encoder blocks (Task 12): each block applies x = x + MHA(LN(x)), then x = x + FFN(LN(x)), using manual layer norm, multi-head self-attention, and GELU activation, with no library shortcuts.

  3. Classifier head β€” project the final [CLS] representation to logits: logits = seq[:, 0, :] @ w_head.
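Step 1 can be sketched in NumPy as follows. This is a minimal sketch, not the reference solution: it assumes each patch is flattened in (C, P, P) order and that patches are numbered row-major (left to right, then top to bottom), which is what a reshape/transpose or PyTorch's unfold produces. A reference that orders the rows of w_proj differently would need a different transpose. The patch size P is recovered from w_proj because it is not passed as an input.

```python
import numpy as np


def patch_embed_with_cls(x, w_proj, cls_token, pos_embed):
    """(N, C, H, W) image batch -> (N, num_patches + 1, d_model) token sequence."""
    n, c, h, w = x.shape
    p = int(round((w_proj.shape[0] / c) ** 0.5))  # P recovered from w_proj's C*P*P rows
    gh, gw = h // p, w // p                        # patch grid: gh rows, gw columns
    # Split H and W into (grid, P), then move the grid axes ahead of C:
    # (N, C, gh, P, gw, P) -> (N, gh, gw, C, P, P) -> (N, gh*gw, C*P*P)
    patches = x.reshape(n, c, gh, p, gw, p).transpose(0, 2, 4, 1, 3, 5)
    patches = patches.reshape(n, gh * gw, c * p * p)
    tokens = patches @ w_proj                                     # (N, num_patches, d_model)
    cls = np.broadcast_to(cls_token, (n, 1, cls_token.shape[0]))  # one [CLS] per image
    seq = np.concatenate([cls, tokens], axis=1)                   # [CLS] goes at index 0
    return seq + pos_embed                                        # pos_embed broadcasts over N
```

The reshape/transpose replaces an explicit loop over patches, so the projection is a single batched matrix multiply.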

Pipeline

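seq = patch_embed_with_cls(x, w_proj, cls_token, pos_embed)   # (N, num_patches + 1, d_model)
num_blocks = blocks_weights.shape[0]
for i in range(num_blocks):
    w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 = blocks_weights[i]     # unpack the six packed matrices
    seq = vit_block(seq, w_q, w_k, w_v, w_o, w_mlp1, w_mlp2, num_heads)
cls_repr = seq[:, 0, :]                                        # (N, d_model): the [CLS] token
logits = cls_repr @ w_head                                     # (N, num_classes)
return logits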
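Putting the pipeline together, a self-contained NumPy sketch of the whole forward pass might look like this. It illustrates the structure under stated assumptions rather than reproducing the reference solution. FFN(h) is taken to be GELU(h @ w_mlp1) @ w_mlp2. Attention splits d_model into num_heads contiguous slices and scales scores by 1/sqrt(head_dim), as in standard scaled dot-product attention. Layer norm uses the biased variance, and patches are flattened in (C, P, P) order. Apart from patch_embed_with_cls and vit_block, the helper names (layer_norm, softmax, gelu, mha, vit_forward) are hypothetical.

```python
import numpy as np

EPS = 1e-5


def layer_norm(x):
    # Normalize over the last (d_model) axis; no learnable scale/shift.
    # Biased variance (np.var default, ddof=0) is an assumption here.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + EPS)


def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)


def gelu(x):
    # tanh approximation of GELU, as given in the constraints
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def mha(x, w_q, w_k, w_v, w_o, num_heads):
    n, t, d = x.shape
    hd = d // num_heads  # per-head width; assumes d_model % num_heads == 0

    def split(m):
        # (N, T, D) -> (N, H, T, hd): each head gets a contiguous slice of D
        return m.reshape(n, t, num_heads, hd).transpose(0, 2, 1, 3)

    q, k, v = split(x @ w_q), split(x @ w_k), split(x @ w_v)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(hd)  # (N, H, T, T)
    out = softmax(scores) @ v                            # (N, H, T, hd)
    out = out.transpose(0, 2, 1, 3).reshape(n, t, d)     # merge heads back to (N, T, D)
    return out @ w_o


def vit_block(x, w_q, w_k, w_v, w_o, w_mlp1, w_mlp2, num_heads):
    x = x + mha(layer_norm(x), w_q, w_k, w_v, w_o, num_heads)  # pre-LN attention
    x = x + gelu(layer_norm(x) @ w_mlp1) @ w_mlp2               # pre-LN FFN, d_ff = d_model
    return x


def patch_embed_with_cls(x, w_proj, cls_token, pos_embed):
    n, c, h, w = x.shape
    p = int(round((w_proj.shape[0] / c) ** 0.5))  # P recovered from C*P*P
    gh, gw = h // p, w // p
    patches = x.reshape(n, c, gh, p, gw, p).transpose(0, 2, 4, 1, 3, 5)
    patches = patches.reshape(n, gh * gw, c * p * p)  # (C, P, P) flatten order assumed
    cls = np.broadcast_to(cls_token, (n, 1, cls_token.shape[0]))
    return np.concatenate([cls, patches @ w_proj], axis=1) + pos_embed


def vit_forward(x, w_proj, cls_token, pos_embed, blocks_weights, w_head, num_heads):
    seq = patch_embed_with_cls(x, w_proj, cls_token, pos_embed)
    for w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 in blocks_weights:  # one slab per block
        seq = vit_block(seq, w_q, w_k, w_v, w_o, w_mlp1, w_mlp2, num_heads)
    return seq[:, 0, :] @ w_head  # classify from the [CLS] token
```

Following the pipeline as written, no layer norm is applied to the final [CLS] representation before w_head. The ViT paper does apply one at that point, so weights trained that way would need an extra layer_norm(seq[:, 0, :]).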

Weight packing convention

blocks_weights has shape (num_blocks, 6, d_model, d_model). For block i, blocks_weights[i, 0..5] are: [w_q, w_k, w_v, w_o, w_mlp1, w_mlp2].

Simplification: d_ff = d_model, so all six weight matrices are (d_model, d_model).
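Because d_ff = d_model, all six slots share one shape, which is what lets the weights live in a single dense array. The packing convention then amounts to stacking six matrices per block and stacking the blocks. A small sketch, assuming NumPy arrays; pack_blocks, unpack_block, and the per-block dicts are hypothetical helpers, not part of the problem's interface:

```python
import numpy as np

# Slot order within each block, as given by the packing convention.
NAMES = ("w_q", "w_k", "w_v", "w_o", "w_mlp1", "w_mlp2")


def pack_blocks(per_block):
    """Stack per-block dicts of (d_model, d_model) matrices into (num_blocks, 6, d_model, d_model)."""
    return np.stack([np.stack([blk[name] for name in NAMES]) for blk in per_block])


def unpack_block(blocks_weights, i):
    """Return block i's six matrices keyed by name; blocks_weights[i, j] holds slot NAMES[j]."""
    return dict(zip(NAMES, blocks_weights[i]))
```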

Implementation constraints

  • No nn.MultiheadAttention, no F.scaled_dot_product_attention, no nn.LayerNorm: implement everything from scratch.
  • Manual layer norm over the last (d_model) dimension, with no learnable scale or shift: (x - mean) / sqrt(var + eps), eps=1e-5.
  • Manual softmax for attention scores.
  • Manual GELU (tanh approximation): 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))).
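The three manual primitives translate directly into NumPy. In this sketch the softmax subtracts each row's maximum before exponentiating. That leaves the result unchanged but avoids overflow for large scores. Layer norm uses np.var's default biased estimator, which matches nn.LayerNorm; whether the reference does the same is an assumption. gelu_exact (x·Φ(x) via math.erf) is a hypothetical helper included only to show how close the required tanh approximation is; it is not part of the task.

```python
import math

import numpy as np

EPS = 1e-5


def layer_norm(x):
    # Normalize each row over its last axis; no gamma/beta since none are provided.
    # np.var defaults to the biased estimator (ddof=0), as nn.LayerNorm does.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + EPS)


def softmax(scores):
    # Subtracting the row max does not change the output but keeps exp() from overflowing.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def gelu_tanh(x):
    # The tanh approximation required by the constraints.
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


def gelu_exact(x):
    # Exact GELU, x * Phi(x), for comparison only.
    flat = [0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in np.ravel(x)]
    return np.array(flat).reshape(np.shape(x))
```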

References

  • Dosovitskiy et al., “An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale”, ICLR 2021.

Inputs / Output

  • x: (N, C, H, W), the image batch.
  • w_proj: (C*P*P, d_model), the patch projection. P is the patch size; it is not passed separately and follows from w_proj.shape[0] = C*P*P.
  • cls_token: (d_model,), the learnable [CLS] token.
  • pos_embed: (num_patches + 1, d_model), the position embeddings, where num_patches = (H/P) * (W/P).
  • blocks_weights: (num_blocks, 6, d_model, d_model), the packed block weights.
  • w_head: (d_model, num_classes), the classifier head.
  • num_heads: int, the number of attention heads.
  • Output: (N, num_classes), the classification logits.
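Since P, num_patches, and num_blocks are not passed in, they have to be read off the input shapes. Below is a hedged pure-Python sketch of that bookkeeping. infer_dims is a hypothetical helper. The checks that H and W are multiples of P and that d_model divides evenly by num_heads are assumptions that standard ViT setups satisfy.

```python
import math


def infer_dims(x_shape, w_proj_shape, pos_embed_shape, blocks_shape, w_head_shape, num_heads):
    """Derive the sizes the inputs leave implicit and check that the shapes agree."""
    n, c, h, w = x_shape
    patch_area, d_model = w_proj_shape
    p = math.isqrt(patch_area // c)  # P is not an input; recover it from C*P*P
    assert c * p * p == patch_area, "w_proj rows must equal C*P*P"
    assert h % p == 0 and w % p == 0, "non-overlapping patches must tile the image"
    num_patches = (h // p) * (w // p)
    assert pos_embed_shape == (num_patches + 1, d_model), "pos_embed needs one row per token"
    num_blocks = blocks_shape[0]
    assert blocks_shape == (num_blocks, 6, d_model, d_model)
    assert w_head_shape[0] == d_model
    assert d_model % num_heads == 0, "d_model must split evenly across heads"
    return {"P": p, "num_patches": num_patches, "num_blocks": num_blocks,
            "head_dim": d_model // num_heads, "num_classes": w_head_shape[1]}
```

For example, 224×224 RGB inputs with 16×16 patches give 196 patches plus [CLS], so pos_embed must have 197 rows.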
