
Vision Transformer with [CLS] Token

Why this matters

The original ViT, like BERT before it, does not mean-pool. It uses a [CLS] token: a single learned vector of size D (stored with shape (1, D)) prepended to the patch sequence. After the encoder, the final representation at position 0 (the [CLS] slot) is fed to the classifier.

The [CLS] token starts as a learned vector that isn’t tied to any patch. As attention runs, every block lets [CLS] attend to all patches and lets all patches see [CLS]. By the final layer, the model has learned to summarise the whole image into the [CLS] position via attention. The classifier reads only that one vector.

Why bother? In practice mean-pool and [CLS] give very similar accuracy on small models. The [CLS] design is preferred when you want one slot to be the “summary” — useful for retrieval (CLIP), distillation, and downstream tasks that pluck a single vector.
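
The two readouts differ only in how the encoder output is reduced to one vector. A minimal sketch, assuming x is an encoder output of shape (num_patches + 1, D) with [CLS] at row 0; the random values are stand-ins for real activations, and in a pure mean-pool model there would be no [CLS] row at all:

```python
import jax
import jax.numpy as jnp

num_patches, D = 4, 8
# Stand-in for encoder output: row 0 is the [CLS] slot, rows 1.. are patches.
x = jax.random.normal(jax.random.PRNGKey(0), (num_patches + 1, D))

cls_readout = x[0]                 # [CLS] design: read one designated slot
mean_readout = x[1:].mean(axis=0)  # mean-pool design: average the patch tokens

print(cls_readout.shape, mean_readout.shape)  # (8,) (8,)
```

Either vector can be passed to the same Dense(num_classes) head; only the reduction step changes.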

Architecture

image (H, W, C)
→ patch embed                       → (num_patches, D)
→ prepend learned [CLS]             → (num_patches + 1, D)
→ + learned pos embed (length+1)    → (num_patches + 1, D)
→ encoder block × N                 → (num_patches + 1, D)
→ LayerNorm                         → (num_patches + 1, D)
→ take position 0 (the CLS)         → (D,)
→ Dense(num_classes)                → (num_classes,)

Four changes vs pos 45:

  1. cls_token = self.param("cls_token", normal(0.02), (1, D)) — a learned (1, D) vector.
  2. seq = jnp.concatenate([cls_token, tokens], axis=0) — sequence length is now num_patches + 1.
  3. Position embedding length is num_patches + 1 to match.
  4. Classifier reads x[0] — the CLS slot — instead of mean-pooling.

Worked walk-through

With image (4, 4, 3), P=2, D=8:

  1. Patch conv → (2, 2, 8). Reshape → (4, 8).
  2. Concat cls_token (1, 8) + tokens (4, 8) → (5, 8).
  3. Add pos embed (5, 8).
  4. Two encoder blocks → (5, 8).
  5. LayerNorm → (5, 8).
  6. cls_out = x[0] → (8,).
  7. Dense(num_classes)(cls_out) → (num_classes,).
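
The same shapes can be traced in plain JAX. This sketch makes several assumptions: weights are random rather than learned, num_classes = 10 is chosen arbitrarily, the encoder blocks are omitted because they preserve shape, and the LayerNorm has no learned scale or bias.

```python
import jax
import jax.numpy as jnp

P, D, num_classes = 2, 8, 10
k_img, k_conv, k_cls, k_pos, k_head = jax.random.split(jax.random.PRNGKey(0), 5)

image = jax.random.normal(k_img, (4, 4, 3))
kernel = 0.02 * jax.random.normal(k_conv, (P, P, 3, D))        # HWIO layout

# 1. Conv with stride P cuts the image into non-overlapping P x P patches.
patches = jax.lax.conv_general_dilated(
    image[None], kernel, window_strides=(P, P), padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))[0]             # (2, 2, 8)
tokens = patches.reshape(-1, D)                                # (4, 8)

# 2-3. Prepend [CLS], then add the position embedding.
cls_token = 0.02 * jax.random.normal(k_cls, (1, D))
pos_embed = 0.02 * jax.random.normal(k_pos, (tokens.shape[0] + 1, D))
x = jnp.concatenate([cls_token, tokens], axis=0) + pos_embed   # (5, 8)

# 4. Encoder blocks keep the (5, 8) shape; omitted in this sketch.
# 5. LayerNorm over the feature axis (no learned scale/bias here).
x = (x - x.mean(-1, keepdims=True)) / jnp.sqrt(x.var(-1, keepdims=True) + 1e-6)

# 6-7. Read the CLS slot and project to logits.
cls_out = x[0]                                                 # (8,)
w_head = 0.02 * jax.random.normal(k_head, (D, num_classes))
logits = cls_out @ w_head                                      # (10,)

print(patches.shape, tokens.shape, x.shape, cls_out.shape, logits.shape)
```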

Why CLS works

Self-attention is permutation-equivariant: it treats every token the same way, so no patch is a natural place to collect a global summary. Adding a single dedicated “summary slot” gives the network a designated place to aggregate information. In every layer, the CLS slot’s attention weights determine how much of each patch’s value vector is mixed into it. After N layers, CLS holds a (learned) summary of the image.
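
To make the aggregation concrete, here is a minimal single-head attention sketch in plain JAX. The projection matrices are random stand-ins (an assumption for illustration), and real ViT blocks use multiple heads plus residual connections.

```python
import jax
import jax.numpy as jnp

D = 8
k_x, k_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_x, (5, D))           # row 0 = [CLS], rows 1..4 = patches
w_q, w_k, w_v = jax.random.normal(k_w, (3, D, D)) / jnp.sqrt(D)

q, k, v = x @ w_q, x @ w_k, x @ w_v
weights = jax.nn.softmax(q @ k.T / jnp.sqrt(D), axis=-1)    # (5, 5)
out = weights @ v                                           # (5, 8)

# Row 0 says how much [CLS] draws from each of the 5 tokens (itself included).
cls_weights = weights[0]
# The updated CLS vector is exactly that weighted mix of value vectors.
cls_update = cls_weights @ v

print(cls_weights.shape, float(cls_weights.sum()))
print(bool(jnp.allclose(out[0], cls_update, atol=1e-5)))
```

Row 0 of the attention matrix is a probability distribution over all tokens, which is what lets the CLS slot pull in information from every patch in a single layer.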

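The position embedding is critical: without it, self-attention sees the patches as an unordered set and cannot tell where in the image each one came from. [CLS] itself stays distinguishable through its own learned vector, and its position-0 embedding adds a further learned offset that only that slot receives.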
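
One way to see what the position embedding contributes: without it, shuffling the patches does not change what the CLS slot computes. A minimal sketch with single-head attention and random weights (both are assumptions for illustration, not the problem's encoder):

```python
import jax
import jax.numpy as jnp

D = 8
k_x, k_w, k_pos = jax.random.split(jax.random.PRNGKey(0), 3)
w_q, w_k, w_v = jax.random.normal(k_w, (3, D, D)) / jnp.sqrt(D)


def attend(x):
    """Single-head self-attention over a (length, D) sequence."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    return jax.nn.softmax(q @ k.T / jnp.sqrt(D), axis=-1) @ v


x = jax.random.normal(k_x, (5, D))        # row 0 = [CLS], rows 1..4 = patches
perm = jnp.array([0, 3, 1, 4, 2])         # keep [CLS] first, shuffle the patches
pos_embed = jax.random.normal(k_pos, (5, D))

# No position embedding: the CLS output ignores patch order.
same_without_pos = jnp.allclose(attend(x)[0], attend(x[perm])[0], atol=1e-4)

# Position embedding added after shuffling: patch order now changes the result.
same_with_pos = jnp.allclose(attend(x + pos_embed)[0],
                             attend(x[perm] + pos_embed)[0], atol=1e-4)

print(bool(same_without_pos), bool(same_with_pos))
```

The first comparison holds because attention sums over keys and values, and a sum does not depend on order; the second fails because each shuffled patch now receives a different position vector.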

Common pitfalls

  • Wrong CLS shape: it’s (1, D) not (D,) — needs the leading length-1 axis to concatenate cleanly.
  • Concat axis: axis=0 (the sequence axis), not axis=-1 (D).
  • Position embedding length: num_patches + 1, NOT num_patches. Forgetting the +1 makes the addition fail with a broadcasting shape error, e.g. (5, 8) + (4, 8).
  • Pulling the wrong slot: x[0], not x[-1] or jnp.mean(x). The CLS slot is at position 0 because that’s where it was prepended.
  • Forgetting LayerNorm before classifier: in original ViT, the final LN comes AFTER the encoder stack and BEFORE the head.
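
The three shape pitfalls are easy to reproduce in isolation. In this sketch the helper fails() is a hypothetical convenience that just reports whether a call raised; the exact exception type (TypeError or ValueError) is assumed to depend on the JAX version, so both are caught.

```python
import jax.numpy as jnp

tokens = jnp.zeros((4, 8))                       # (num_patches, D)


def fails(fn):
    """Hypothetical helper: True if fn() raises a shape-related error."""
    try:
        fn()
    except (TypeError, ValueError):
        return True
    return False


# CLS stored as (D,) instead of (1, D): ranks differ, concatenate rejects it.
bad_cls_shape = fails(lambda: jnp.concatenate([jnp.zeros((8,)), tokens], axis=0))

# Concatenating on the feature axis: (1, 8) and (4, 8) disagree on axis 0.
bad_axis = fails(lambda: jnp.concatenate([jnp.zeros((1, 8)), tokens], axis=-1))

# Position embedding of length num_patches instead of num_patches + 1.
seq = jnp.concatenate([jnp.zeros((1, 8)), tokens], axis=0)     # (5, 8)
bad_pos_length = fails(lambda: seq + jnp.zeros((4, 8)))

print(bad_cls_shape, bad_axis, bad_pos_length)   # True True True
```

Reading x[-1] or jnp.mean(x) instead of x[0] raises nothing, since all three return a (D,) vector; that mistake only shows up as wrong outputs.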

Problem

Implement vit_cls_forward(seed, image, patch_size, d_model, num_heads, d_ff, num_layers, num_classes):

  1. Patch-embed via strided conv.
  2. Prepend a learned cls_token = self.param("cls_token", normal(0.02), (1, D)).
  3. Add a learned position embedding of length num_patches + 1.
  4. N encoder blocks.
  5. Final LayerNorm.
  6. Take x[0] (the CLS slot).
  7. Dense(num_classes)(cls_out).

Inputs: same as pos 45.

Output: 1-D, length num_classes.

Hints

flax vit cls-token vision-transformer
