Vision Transformer with [CLS] Token
Why this matters
The original ViT (and BERT before it) doesn’t mean-pool. It uses a
[CLS] token: a single learned D-dimensional vector PREPENDED to the
patch sequence. After the encoder, the final representation at
position 0 (the [CLS] slot) is fed to the classifier.
The [CLS] token starts as a learned vector that isn’t tied to any
patch. In every encoder block, self-attention lets [CLS] attend to
all patches and lets all patches attend to [CLS]. By the final layer,
the model has learned to summarise the whole image into the [CLS]
position via attention. The classifier reads only that one vector.
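To make "[CLS] attends to all patches" concrete, here is a minimal single-head attention sketch in plain JAX. It assumes random stand-in weights and patch embeddings (nothing is trained), and the helper name single_head_attention is hypothetical. Row 0 of the attention matrix shows how much the CLS slot reads from each position, itself included.

```python
import jax
import jax.numpy as jnp


def single_head_attention(x, wq, wk, wv):
    """Single-head scaled dot-product self-attention over a (T, D) sequence."""
    q, k, v = x @ wq, x @ wk, x @ wv           # each (T, D)
    scores = q @ k.T / jnp.sqrt(q.shape[-1])   # (T, T): every token scores every token
    weights = jax.nn.softmax(scores, axis=-1)  # each row is non-negative and sums to 1
    return weights @ v, weights


D, num_patches = 8, 4
k_cls, k_patch, k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 5)
cls_token = 0.02 * jax.random.normal(k_cls, (1, D))     # stand-in for the learned CLS
patches = jax.random.normal(k_patch, (num_patches, D))  # stand-in patch embeddings
x = jnp.concatenate([cls_token, patches], axis=0)       # (5, D), CLS at row 0
wq, wk, wv = (jax.random.normal(k, (D, D)) / jnp.sqrt(D) for k in (k_q, k_k, k_v))

out, weights = single_head_attention(x, wq, wk, wv)
cls_weights = weights[0]  # how much the CLS slot reads from each of the 5 positions
print(cls_weights)
```

Since out[0] equals cls_weights @ v, the CLS output after one layer is a weighted average of every token's value vector; stacking layers (with MLPs in between) is how it builds up an image-level summary.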
Why bother? In practice mean-pool and [CLS] give very similar
accuracy on small models. The [CLS] design is preferred when you
want one slot to be the “summary” — useful for retrieval (CLIP),
distillation, and downstream tasks that pluck a single vector.
Architecture
image (H, W, C)
→ patch embed → (num_patches, D)
→ prepend learned [CLS] → (num_patches + 1, D)
→ + learned pos embed (length num_patches + 1) → (num_patches + 1, D)
→ encoder block × N → (num_patches + 1, D)
→ LayerNorm → (num_patches + 1, D)
→ take position 0 (the CLS) → (D,)
→ Dense(num_classes) → (num_classes,)
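The first arrow, "patch embed", is done with a strided conv in the Problem below. The sketch here checks that a conv whose kernel size and stride both equal P gives the same result as cutting non-overlapping P×P patches and applying one shared linear map. It assumes NHWC images and an HWIO kernel, omits the conv bias, and the two function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def patch_embed_conv(image, kernel, patch_size):
    """Strided conv with kernel size == stride == patch_size (NHWC / HWIO)."""
    out = jax.lax.conv_general_dilated(
        image[None],                           # add batch axis: (1, H, W, C)
        kernel,                                # (P, P, C, D)
        window_strides=(patch_size, patch_size),
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )[0]                                       # (H/P, W/P, D)
    return out.reshape(-1, out.shape[-1])      # (num_patches, D), row-major patch order


def patch_embed_reshape(image, kernel, patch_size):
    """Same map: cut non-overlapping patches, flatten each, then one matmul."""
    H, W, C = image.shape
    P = patch_size
    patches = image.reshape(H // P, P, W // P, P, C).transpose(0, 2, 1, 3, 4)
    patches = patches.reshape(-1, P * P * C)        # (num_patches, P*P*C)
    return patches @ kernel.reshape(P * P * C, -1)  # (num_patches, D)


k_img, k_ker = jax.random.split(jax.random.PRNGKey(0))
image = jax.random.normal(k_img, (4, 4, 3))
kernel = jax.random.normal(k_ker, (2, 2, 3, 8))
a = patch_embed_conv(image, kernel, 2)
b = patch_embed_reshape(image, kernel, 2)
print(a.shape)  # (4, 8)
```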
Changes vs pos 45:
- cls_token = self.param("cls_token", normal(0.02), (1, D)): a learned (1, D) vector.
- seq = jnp.concatenate([cls_token, tokens], axis=0): the sequence length is now num_patches + 1.
- The position embedding has length num_patches + 1 to match.
- The classifier reads x[0] (the CLS slot) instead of mean-pooling.
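The pieces above can be sketched without Flax by replacing self.param with explicit jax.random draws. This is an assumption for self-containment, not the Flax API; the helper names init_cls_params, add_cls_and_pos and read_cls are hypothetical, and using normal(0.02) for the position embedding too is an assumption (the text only specifies it for cls_token).

```python
import jax
import jax.numpy as jnp


def init_cls_params(key, num_patches, d_model, std=0.02):
    """Stand-ins for the learned cls_token and position embedding."""
    k_cls, k_pos = jax.random.split(key)
    return {
        "cls_token": std * jax.random.normal(k_cls, (1, d_model)),                # (1, D)
        "pos_embed": std * jax.random.normal(k_pos, (num_patches + 1, d_model)),  # (N + 1, D)
    }


def add_cls_and_pos(params, tokens):
    """(num_patches, D) -> (num_patches + 1, D), with CLS prepended at row 0."""
    seq = jnp.concatenate([params["cls_token"], tokens], axis=0)  # sequence axis
    return seq + params["pos_embed"]                              # lengths now match


def read_cls(x):
    """Classifier input: the CLS slot, not a mean over tokens."""
    return x[0]


params = init_cls_params(jax.random.PRNGKey(0), num_patches=4, d_model=8)
tokens = jnp.ones((4, 8))
seq = add_cls_and_pos(params, tokens)
print(seq.shape, read_cls(seq).shape)  # (5, 8) (8,)
```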
Worked walk-through
With image (4, 4, 3), P=2, D=8:
- Patch conv → (2, 2, 8). Reshape → (4, 8).
- Concat cls_token (1, 8) with tokens (4, 8) → (5, 8).
- Add pos embed (5, 8) → (5, 8).
- Two encoder blocks → (5, 8).
- LayerNorm → (5, 8).
- cls_out = x[0] → (8,).
- Dense(num_classes)(cls_out) → (num_classes,).
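The walk-through can be checked end to end with the pure-JAX sketch below. It is not the reference solution to the Problem: it uses plain arrays instead of Flax modules, and it assumes a pre-LN encoder block with a GELU MLP, normal(0.02) init for every weight, and no biases or LayerNorm scale/offset. The names init_params_sketch and vit_cls_forward_sketch are hypothetical.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Normalise over the feature axis; learnable scale/offset omitted (assumption).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def mha(x, p, num_heads):
    # Multi-head self-attention over a (T, D) sequence, no biases.
    T, D = x.shape
    hd = D // num_heads

    def split(t):  # (T, D) -> (heads, T, hd)
        return t.reshape(T, num_heads, hd).transpose(1, 0, 2)

    q, k, v = split(x @ p["wq"]), split(x @ p["wk"]), split(x @ p["wv"])
    w = jax.nn.softmax(q @ k.transpose(0, 2, 1) / jnp.sqrt(hd), axis=-1)  # (heads, T, T)
    out = (w @ v).transpose(1, 0, 2).reshape(T, D)                       # merge heads
    return out @ p["wo"]


def encoder_block(x, p, num_heads):
    # Pre-LN block (assumption): x + MHA(LN(x)), then x + MLP(LN(x)).
    x = x + mha(layer_norm(x), p, num_heads)
    return x + jax.nn.gelu(layer_norm(x) @ p["w1"]) @ p["w2"]


def init_params_sketch(key, image_shape, patch_size, d_model, d_ff, num_layers, num_classes):
    H, W, C = image_shape
    num_patches = (H // patch_size) * (W // patch_size)
    keys = iter(jax.random.split(key, 4 + 6 * num_layers))

    def n(shape):  # every parameter drawn from normal(0.02) (assumption)
        return 0.02 * jax.random.normal(next(keys), shape)

    return {
        "patch_kernel": n((patch_size, patch_size, C, d_model)),  # HWIO conv kernel
        "cls_token": n((1, d_model)),
        "pos_embed": n((num_patches + 1, d_model)),
        "head": n((d_model, num_classes)),
        "blocks": [
            {"wq": n((d_model, d_model)), "wk": n((d_model, d_model)),
             "wv": n((d_model, d_model)), "wo": n((d_model, d_model)),
             "w1": n((d_model, d_ff)), "w2": n((d_ff, d_model))}
            for _ in range(num_layers)
        ],
    }


def vit_cls_forward_sketch(params, image, patch_size, num_heads):
    x = jax.lax.conv_general_dilated(                      # strided patch conv
        image[None], params["patch_kernel"], (patch_size, patch_size), "VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))[0]     # (H/P, W/P, D)
    x = x.reshape(-1, x.shape[-1])                         # (num_patches, D)
    x = jnp.concatenate([params["cls_token"], x], axis=0)  # (num_patches + 1, D)
    x = x + params["pos_embed"]                            # learned pos embed
    for p in params["blocks"]:
        x = encoder_block(x, p, num_heads)                 # shape unchanged
    x = layer_norm(x)                                      # final LN, before the head
    cls_out = x[0]                                         # (D,)
    return cls_out @ params["head"]                        # (num_classes,)


# Walk-through sizes: image (4, 4, 3), P=2, D=8, two blocks.
# 2 heads, d_ff=16 and 10 classes are arbitrary choices for this demo.
params = init_params_sketch(jax.random.PRNGKey(0), (4, 4, 3), 2, 8, 16, 2, 10)
logits = vit_cls_forward_sketch(params, jnp.ones((4, 4, 3)), patch_size=2, num_heads=2)
print(logits.shape)  # (10,)
```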
Why CLS works
Self-attention is permutation-equivariant, so on its own no position is special. Adding a single “summary slot” at a fixed index gives the network a designated place to aggregate information. Every layer’s attention can copy patch info into the CLS slot via its attention weights. After N layers, CLS holds a (learned) summary of the image.
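The position embedding is critical: without it the encoder is
blind to patch order (shuffling the patches would leave the [CLS]
output unchanged), so the model loses the image's spatial layout.
[CLS] itself is still recognisable by its learned content, but the
patches around it carry no information about where they came from.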
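A small experiment shows the effect. It assumes one single-head attention layer with a residual connection as a deliberately reduced stand-in for a full encoder block, with random weights; the helper names attention_layer and cls_output are hypothetical. The same patches are fed in two different orders, with and without a position embedding.

```python
import jax
import jax.numpy as jnp


def attention_layer(x, wq, wk, wv):
    """Single-head self-attention with a residual connection, over (T, D)."""
    q, k, v = x @ wq, x @ wk, x @ wv
    w = jax.nn.softmax(q @ k.T / jnp.sqrt(x.shape[-1]), axis=-1)
    return x + w @ v


D, N = 8, 4
ks = jax.random.split(jax.random.PRNGKey(0), 6)
cls_token = jax.random.normal(ks[0], (1, D))
patches = jax.random.normal(ks[1], (N, D))
pos_embed = jax.random.normal(ks[2], (N + 1, D))
wq, wk, wv = (jax.random.normal(k, (D, D)) / jnp.sqrt(D) for k in ks[3:])


def cls_output(patch_tokens, use_pos):
    """CLS output after one layer; CLS is always prepended at slot 0."""
    x = jnp.concatenate([cls_token, patch_tokens], axis=0)
    if use_pos:
        x = x + pos_embed
    return attention_layer(x, wq, wk, wv)[0]


shuffled = patches[jnp.array([2, 0, 3, 1])]  # same patches, different order

# No position embedding: the CLS output ignores patch order.
print(jnp.allclose(cls_output(patches, False), cls_output(shuffled, False), atol=1e-5))  # True
# With a position embedding, reordering the patches changes the CLS output.
print(jnp.allclose(cls_output(patches, True), cls_output(shuffled, True), atol=1e-5))
```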
Common pitfalls
- Wrong CLS shape: it’s (1, D), not (D,). The leading length-1 axis is what lets it concatenate cleanly with the (num_patches, D) tokens.
- Concat axis: axis=0 (the sequence axis), not axis=-1 (D).
- Position embedding length: num_patches + 1, NOT num_patches. Forgetting the +1 makes the addition fail with a shape mismatch between (num_patches + 1, D) and (num_patches, D).
- Pulling the wrong slot: x[0], not x[-1] or jnp.mean(x, axis=0). The CLS slot is at position 0 because that’s where it was prepended.
- Forgetting LayerNorm before the classifier: in the original ViT, the final LN comes AFTER the encoder stack and BEFORE the head.
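Most of these pitfalls fail loudly, but reading the wrong slot does not. The sketch below checks which variants raise in eager JAX; the helper raises and the dictionary labels are hypothetical, and the catch is deliberately broad because only "does it fail?" matters here.

```python
import jax.numpy as jnp

D, N = 8, 4
tokens = jnp.ones((N, D))
cls_good = jnp.zeros((1, D))  # correct shape
cls_flat = jnp.zeros((D,))    # pitfall: missing the leading length-1 axis


def raises(fn):
    """Return True if fn() raises; eager JAX reports these shape mismatches immediately."""
    try:
        fn()
    except Exception:  # broad on purpose: only "does it fail?" matters here
        return True
    return False


checks = {
    "cls (1, D), axis=0": raises(lambda: jnp.concatenate([cls_good, tokens], axis=0)),
    "cls (D,), axis=0": raises(lambda: jnp.concatenate([cls_flat, tokens], axis=0)),
    "cls (1, D), axis=-1": raises(lambda: jnp.concatenate([cls_good, tokens], axis=-1)),
    "pos embed length N": raises(
        lambda: jnp.concatenate([cls_good, tokens], axis=0) + jnp.zeros((N, D))),
}
print(checks)  # only the first (correct) variant is False

# Reading the wrong slot runs without error and just returns a different vector.
x = jnp.concatenate([jnp.full((1, D), 5.0), tokens], axis=0)  # pretend encoder output
cls_out, last_out, mean_out = x[0], x[-1], jnp.mean(x, axis=0)
```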
Problem
Implement vit_cls_forward(seed, image, patch_size, d_model, num_heads, d_ff, num_layers, num_classes):
- Patch-embed via strided conv.
- Prepend a learned cls_token = self.param("cls_token", normal(0.02), (1, D)).
- Add a learned position embedding of length num_patches + 1.
- N encoder blocks.
- Final LayerNorm.
- Take x[0] (the CLS slot).
- Dense(num_classes)(cls_out).
Inputs: same as pos 45.
Output: 1-D, length num_classes.