medium end_to_end

Patch Embedding with CLS Token

Implement patch embedding with a [CLS] token: the input stage of a Vision Transformer (ViT), which converts raw images into a sequence a transformer encoder can process.

Why patches?

Transformers operate on sequences of vectors. For images, the simplest sequence is a grid of non-overlapping square patches read in row-major order. Each P×P patch (across all C channels) is flattened into a vector of length C*P*P and linearly projected to d_model. This is mathematically equivalent to a bias-free Conv2d(C, d_model, kernel_size=P, stride=P) whose kernel is the projection matrix reshaped to (d_model, C, P, P), but here it is implemented as reshape + matmul, so no convolution is needed.
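The equivalence between reshape + matmul and a strided convolution can be checked numerically. The following is a minimal NumPy sketch, not a reference solution: the helper names patchify and strided_conv are hypothetical, the convolution is written as an explicit loop over windows rather than a library call, and the kernel layout (d_model, C, P, P) is assumed to follow the usual Conv2d weight convention.

```python
import numpy as np

def patchify(x, P):
    """Split (N, C, H, W) images into flat patches of shape (N, num_patches, C*P*P)."""
    N, C, H, W = x.shape
    Hp, Wp = H // P, W // P
    x = x.reshape(N, C, Hp, P, Wp, P)
    x = x.transpose(0, 2, 4, 1, 3, 5)  # (N, Hp, Wp, C, P, P)
    return x.reshape(N, Hp * Wp, C * P * P)

def strided_conv(x, kernel, P):
    """Bias-free convolution with kernel_size = stride = P, written as a loop.

    kernel has shape (d_model, C, P, P); the result has shape (N, d_model, Hp, Wp).
    """
    N, C, H, W = x.shape
    d_model = kernel.shape[0]
    Hp, Wp = H // P, W // P
    out = np.zeros((N, d_model, Hp, Wp))
    for i in range(Hp):
        for j in range(Wp):
            window = x[:, :, i * P:(i + 1) * P, j * P:(j + 1) * P]  # (N, C, P, P)
            out[:, :, i, j] = np.einsum("ncpq,dcpq->nd", window, kernel)
    return out

rng = np.random.default_rng(0)
N, C, H, W, P, d_model = 2, 3, 8, 8, 4, 5
x = rng.normal(size=(N, C, H, W))
w_proj = rng.normal(size=(C * P * P, d_model))

# Route 1: reshape + matmul.
via_matmul = patchify(x, P) @ w_proj  # (N, num_patches, d_model)

# Route 2: the same weights viewed as a convolution kernel.
kernel = w_proj.T.reshape(d_model, C, P, P)
via_conv = strided_conv(x, kernel, P)  # (N, d_model, Hp, Wp)
via_conv = via_conv.reshape(N, d_model, -1).transpose(0, 2, 1)  # (N, num_patches, d_model)

print(np.allclose(via_matmul, via_conv))
```

Both routes compute the same dot product between each patch and each column of w_proj; they differ only in how the patch elements are gathered.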

[CLS] token

Following BERT, ViT prepends a learnable [CLS] token to the patch sequence. After all transformer layers, the final representation at the [CLS] position is fed to the classifier head. Because attention is global, this token aggregates information from every patch.

Position embeddings

Self-attention is permutation-equivariant: without position information, the model cannot distinguish patch order. ViT uses learned absolute position embeddings of shape (num_patches + 1, d_model), one vector per position in the sequence (including the [CLS] slot at index 0). These are added element-wise to the sequence after the [CLS] token is prepended.
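To illustrate why position embeddings matter, here is a small NumPy sketch under simplifying assumptions: a single-head self-attention layer with random weights (the function self_attention and all variable names are hypothetical, not part of the problem). Shuffling the input tokens just shuffles the output rows in the same way, until a position embedding is added.

```python
import numpy as np

def self_attention(x, wq, wk, wv):
    """Single-head scaled dot-product self-attention over a (T, d) sequence."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
T, d = 5, 8
x = rng.normal(size=(T, d))
wq, wk, wv = (rng.normal(size=(d, d)) for _ in range(3))
pos = rng.normal(size=(T, d))       # stand-in for learned position embeddings
perm = np.array([2, 0, 4, 1, 3])    # a fixed shuffle of the token order

# Without positions: permuting the input permutes the output identically.
out = self_attention(x, wq, wk, wv)
out_shuffled = self_attention(x[perm], wq, wk, wv)
equivariant = np.allclose(out[perm], out_shuffled)

# With positions: the same tokens in a different order give different outputs.
out_pos = self_attention(x + pos, wq, wk, wv)
out_pos_shuffled = self_attention(x[perm] + pos, wq, wk, wv)
position_aware = not np.allclose(out_pos[perm], out_pos_shuffled)

print(equivariant, position_aware)
```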

Pipeline

Given image batch x of shape (N, C, H, W), projection weights w_proj of shape (C*P*P, d_model), a [CLS] token vector of shape (d_model,), and position embeddings of shape (num_patches + 1, d_model):

  1. Extract patches (no Conv2d!): Reshape x from (N, C, H, W) → (N, C, H_p, P, W_p, P) → permute to (N, H_p, W_p, C, P, P) → reshape to (N, num_patches, C*P*P), where H_p = H // P, W_p = W // P, and num_patches = H_p * W_p.

  2. Project: patches @ w_proj → (N, num_patches, d_model).

  3. Prepend [CLS]: broadcast cls_token to (N, 1, d_model) and concatenate: seq = concat([cls_broadcast, patches], dim=1) → (N, num_patches + 1, d_model).

  4. Add position embeddings: seq + pos_embed (broadcast over batch) → (N, num_patches + 1, d_model).
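The four steps above can be sketched in NumPy as follows. This is one possible reading of the pipeline, not the official solution: the function name patch_embed_with_cls and the explicit patch_size argument are assumptions (the problem statement does not say how P is supplied), and NumPy's axis=1 stands in for the dim=1 notation used above.

```python
import numpy as np

def patch_embed_with_cls(x, w_proj, cls_token, pos_embed, patch_size):
    """Turn (N, C, H, W) images into a (N, num_patches + 1, d_model) sequence."""
    N, C, H, W = x.shape
    P = patch_size
    if H % P or W % P:
        raise ValueError("patch_size must divide both H and W")
    Hp, Wp = H // P, W // P

    # Step 1: extract patches with reshape + permute (no convolution).
    patches = x.reshape(N, C, Hp, P, Wp, P)
    patches = patches.transpose(0, 2, 4, 1, 3, 5)      # (N, Hp, Wp, C, P, P)
    patches = patches.reshape(N, Hp * Wp, C * P * P)   # (N, num_patches, C*P*P)

    # Step 2: linear projection to d_model.
    tokens = patches @ w_proj                          # (N, num_patches, d_model)

    # Step 3: prepend the [CLS] token, broadcast over the batch.
    cls = np.broadcast_to(cls_token, (N, 1, tokens.shape[-1]))
    seq = np.concatenate([cls, tokens], axis=1)        # (N, num_patches + 1, d_model)

    # Step 4: add position embeddings (broadcast over the batch).
    return seq + pos_embed

rng = np.random.default_rng(0)
N, C, H, W, P, d_model = 2, 3, 8, 12, 4, 16
num_patches = (H // P) * (W // P)
x = rng.normal(size=(N, C, H, W))
w_proj = rng.normal(size=(C * P * P, d_model))
cls_token = rng.normal(size=(d_model,))
pos_embed = rng.normal(size=(num_patches + 1, d_model))

seq = patch_embed_with_cls(x, w_proj, cls_token, pos_embed, P)
print(seq.shape)  # (2, 7, 16)
```

Index 0 of the output holds cls_token + pos_embed[0] for every image; indices 1 onward hold the projected patches in row-major order, each offset by its own position embedding.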

Reference

Dosovitskiy et al., “An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale”, ICLR 2021.

Inputs / Output

  • x: shape (N, C, H, W), a batch of images. P divides both H and W.
  • w_proj: shape (C*P*P, d_model), the flat-patch projection matrix.
  • cls_token: shape (d_model,), the learned [CLS] embedding vector.
  • pos_embed: shape (num_patches + 1, d_model), the learned position embeddings.
  • Output: shape (N, num_patches + 1, d_model).
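As a quick sanity check on these shapes, the sketch below computes them for a concrete configuration. The helper embedding_shapes is hypothetical, and the example numbers (224×224 RGB images, P = 16, d_model = 768, batch size 8) are chosen here for illustration rather than taken from the problem.

```python
def embedding_shapes(N, C, H, W, P, d_model):
    """Return the tensor shapes implied by the inputs/output specification."""
    if H % P or W % P:
        raise ValueError("P must divide both H and W")
    num_patches = (H // P) * (W // P)
    return {
        "x": (N, C, H, W),
        "w_proj": (C * P * P, d_model),
        "cls_token": (d_model,),
        "pos_embed": (num_patches + 1, d_model),
        "output": (N, num_patches + 1, d_model),
    }

shapes = embedding_shapes(N=8, C=3, H=224, W=224, P=16, d_model=768)
for name, shape in shapes.items():
    print(name, shape)
# output has shape (8, 197, 768): 14 * 14 = 196 patches plus one [CLS] slot
```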

Hints

vit transformer embedding
