Patch Embedding with CLS Token
Implement patch embedding with a [CLS] token: the input stage of a Vision Transformer (ViT), which converts raw images into a sequence that a transformer encoder can process.
Why patches?
Transformers operate on sequences of vectors. For images, the simplest
sequence is a grid of non-overlapping square patches read in row-major
order. Each P×P patch (across all C channels) is flattened into a C*P*P
vector and linearly projected to d_model. This is equivalent to a bias-free
Conv2d(C, d_model, kernel_size=P, stride=P), but implemented as
reshape + matmul, so no convolution is needed.
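The equivalence can be checked numerically. The following is a minimal NumPy sketch, not a reference solution: the helper name `extract_patches`, the toy sizes, and the choice of NumPy are assumptions made for illustration. It compares reshape + matmul against an explicit stride-P sliding window that uses the same weights rearranged into a Conv2d-style (d_model, C, P, P) layout.

```python
import numpy as np

def extract_patches(x, P):
    """Split (N, C, H, W) images into flattened (N, num_patches, C*P*P) patches."""
    N, C, H, W = x.shape
    H_p, W_p = H // P, W // P
    x = x.reshape(N, C, H_p, P, W_p, P)
    x = x.transpose(0, 2, 4, 1, 3, 5)  # (N, H_p, W_p, C, P, P)
    return x.reshape(N, H_p * W_p, C * P * P)

rng = np.random.default_rng(0)
N, C, H, W, P, d_model = 2, 3, 8, 8, 4, 5  # toy sizes, chosen arbitrarily
x = rng.standard_normal((N, C, H, W))
w_proj = rng.standard_normal((C * P * P, d_model))

# Reshape + matmul path.
tokens = extract_patches(x, P) @ w_proj  # (N, num_patches, d_model)

# Reference path: slide a PxP window with stride P, as a convolution would.
kernel = w_proj.T.reshape(d_model, C, P, P)  # Conv2d-style weight layout
ref = np.empty_like(tokens)
for i in range(H // P):
    for j in range(W // P):
        window = x[:, :, i * P:(i + 1) * P, j * P:(j + 1) * P]  # (N, C, P, P)
        ref[:, i * (W // P) + j] = np.einsum("ncpq,dcpq->nd", window, kernel)

assert np.allclose(tokens, ref)
```

The patch index `i * (W // P) + j` reflects the row-major order that the final reshape produces.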
[CLS] token
Following BERT, ViT prepends a learnable [CLS] token to the patch sequence. After all transformer layers, the final representation at the [CLS] position is fed to the classifier head. Because attention is global, this token aggregates information from every patch.
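Prepending the token is a broadcast followed by a concatenation along the sequence axis. A small NumPy sketch, where the function name `prepend_cls` and the toy values are illustrative assumptions:

```python
import numpy as np

def prepend_cls(patches, cls_token):
    """Put the same [CLS] vector at index 0 of every sequence in the batch."""
    N, _, d_model = patches.shape
    # (d_model,) -> (N, 1, d_model); a read-only view, no copy is made here.
    cls_broadcast = np.broadcast_to(cls_token, (N, 1, d_model))
    return np.concatenate([cls_broadcast, patches], axis=1)

patches = np.zeros((2, 4, 3))          # N=2, num_patches=4, d_model=3
cls_token = np.array([1.0, 2.0, 3.0])  # stands in for a learned vector
seq = prepend_cls(patches, cls_token)  # shape (2, 5, 3)
```

Every batch element receives the same vector at position 0; only the later attention layers make its representation image-specific.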
Position embeddings
Transformers are permutation-equivariant: without position information the
model cannot distinguish patch order. ViT uses learned absolute position
embeddings of shape (num_patches + 1, d_model), one vector per position
in the sequence (including the [CLS] slot at index 0). These are simply added
to the sequence after prepending the [CLS] token.
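To see why position information matters, the sketch below runs a stripped-down self-attention over a single sequence, with and without position embeddings. The `self_attention` helper uses identity query/key/value projections and random data, an assumption made purely to keep the example short; it is not the ViT encoder. Without position embeddings the [CLS] output is unchanged when the patches are shuffled; with them, it changes.

```python
import numpy as np

def self_attention(x):
    """Single-head attention on (T, d) with identity Q/K/V (illustrative only)."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x

rng = np.random.default_rng(0)
seq = rng.standard_normal((5, 4))        # [CLS] + 4 patches, d_model=4
pos_embed = rng.standard_normal((5, 4))  # one vector per position
perm = np.array([0, 3, 1, 4, 2])         # shuffle patches, keep [CLS] at 0

# No position embeddings: the [CLS] output ignores patch order.
cls_plain = self_attention(seq)[0]
cls_plain_shuffled = self_attention(seq[perm])[0]
assert np.allclose(cls_plain, cls_plain_shuffled)

# With position embeddings: the same shuffle changes the [CLS] output.
cls_pos = self_attention(seq + pos_embed)[0]
cls_pos_shuffled = self_attention(seq[perm] + pos_embed)[0]
assert not np.allclose(cls_pos, cls_pos_shuffled)
```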
Pipeline
Given image batch x of shape (N, C, H, W), projection weights w_proj
of shape (C*P*P, d_model), a [CLS] token vector of shape (d_model,), and
position embeddings of shape (num_patches + 1, d_model):
- Extract patches (no Conv2d!): reshape x from (N, C, H, W) → (N, C, H_p, P, W_p, P) → permute to (N, H_p, W_p, C, P, P) → reshape to (N, num_patches, C*P*P), where H_p = H // P, W_p = W // P, and num_patches = H_p * W_p.
- Project: patches @ w_proj → (N, num_patches, d_model).
- Prepend [CLS]: broadcast cls_token to (N, 1, d_model) and concatenate: seq = concat([cls_broadcast, patches], dim=1) → (N, num_patches + 1, d_model).
- Add position embeddings: seq + pos_embed (broadcast over batch) → (N, num_patches + 1, d_model).
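The four steps can be sketched end to end as follows. This is one possible NumPy rendering under stated assumptions, not the problem's official solution: the function name `patch_embed` is hypothetical, and the patch size `P` is passed explicitly as an extra argument even though the problem lists only the four tensors.

```python
import numpy as np

def patch_embed(x, w_proj, cls_token, pos_embed, P):
    """Images (N, C, H, W) -> token sequence (N, num_patches + 1, d_model)."""
    N, C, H, W = x.shape
    if H % P or W % P:
        raise ValueError("P must divide both H and W")
    H_p, W_p = H // P, W // P
    d_model = w_proj.shape[1]

    # 1. Extract patches with reshape + permute, no convolution.
    patches = x.reshape(N, C, H_p, P, W_p, P)
    patches = patches.transpose(0, 2, 4, 1, 3, 5)  # (N, H_p, W_p, C, P, P)
    patches = patches.reshape(N, H_p * W_p, C * P * P)

    # 2. Project each flattened patch to d_model.
    tokens = patches @ w_proj  # (N, num_patches, d_model)

    # 3. Prepend the [CLS] token to every sequence in the batch.
    cls_broadcast = np.broadcast_to(cls_token, (N, 1, d_model))
    seq = np.concatenate([cls_broadcast, tokens], axis=1)

    # 4. Add position embeddings, broadcast over the batch axis.
    return seq + pos_embed

# Usage with toy sizes: all-zero images make the result easy to inspect.
rng = np.random.default_rng(0)
N, C, H, W, P, d_model = 2, 3, 8, 8, 4, 6
num_patches = (H // P) * (W // P)
x = np.zeros((N, C, H, W))
w_proj = rng.standard_normal((C * P * P, d_model))
cls_token = rng.standard_normal(d_model)
pos_embed = rng.standard_normal((num_patches + 1, d_model))

out = patch_embed(x, w_proj, cls_token, pos_embed, P)  # shape (2, 5, 6)
```

With zero images every projected patch is zero, so position 0 holds `cls_token + pos_embed[0]` and the remaining positions hold `pos_embed[1:]`.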
Reference
Dosovitskiy et al., "An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale", ICLR 2021.
Inputs / Output
- x: shape (N, C, H, W), batch of images. P divides both H and W.
- w_proj: shape (C*P*P, d_model), flat-patch projection matrix.
- cls_token: shape (d_model,), learned [CLS] embedding vector.
- pos_embed: shape (num_patches + 1, d_model), learned position embeddings.
- Output: shape (N, num_patches + 1, d_model).
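As a quick check on the shape arithmetic, the sketch below computes every shape from the image and patch sizes. The helper name `embedding_shapes` is hypothetical, and the example numbers (224×224 RGB images, P = 16, d_model = 768) are chosen to match the common ViT-Base/16 configuration.

```python
def embedding_shapes(N, C, H, W, P, d_model):
    """Return the expected shape of each input and of the output."""
    if H % P != 0 or W % P != 0:
        raise ValueError("P must divide both H and W")
    num_patches = (H // P) * (W // P)
    return {
        "x": (N, C, H, W),
        "w_proj": (C * P * P, d_model),
        "cls_token": (d_model,),
        "pos_embed": (num_patches + 1, d_model),
        "output": (N, num_patches + 1, d_model),
    }

shapes = embedding_shapes(N=8, C=3, H=224, W=224, P=16, d_model=768)
# 224 // 16 = 14 patches per side, so 14 * 14 = 196 patches plus one [CLS] slot.
print(shapes["pos_embed"])  # (197, 768)
print(shapes["output"])     # (8, 197, 768)
```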