
DeiT — Data-Efficient Image Transformer

Why this matters

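The original ViT (pos 46) needed enormous pretraining datasets (JFT-300M) to beat ConvNets. DeiT (Touvron et al., 2021) closed that gap while training on ImageNet-1k alone. It did this with a much stronger training recipe (heavy augmentation and regularization) plus one architectural trick: a distillation token.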

DeiT’s contribution: prepend not one but TWO learned tokens to the patch sequence — [CLS] AND [DIST]. The student model is trained with two losses:

  1. Standard cross-entropy from the [CLS] head against the ground truth label.
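  2. Distillation loss from the [DIST] head against a teacher model’s predictions. In the paper the teacher is a pretrained RegNetY-16GF ConvNet, and the best-performing variant uses the teacher’s hard top-1 label as the target.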

The two heads share the whole encoder; only the final linear heads differ. At inference time you average the two heads’ logits, or use either head alone. Together with DeiT’s training recipe, this let a plain ViT match strong ConvNets while training on ImageNet alone, with no extra pretraining data.
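The sketch below shows the two training losses and the inference-time fusion for a single example. It assumes the hard-label distillation variant, in which the teacher’s argmax serves as a second label and the two cross-entropies are weighted 0.5 each. Following the text above, inference averages the logits. The function names are hypothetical, and `teacher_logits` stands in for whatever a frozen teacher outputs on the same image.

```python
import jax
import jax.numpy as jnp


def cross_entropy(logits, label):
    # Single-example cross-entropy: -log softmax(logits)[label].
    return -jax.nn.log_softmax(logits)[label]


def deit_hard_distill_loss(cls_logits, dist_logits, label, teacher_logits):
    # Hard distillation: the teacher's top-1 class acts as a second label.
    teacher_label = jnp.argmax(teacher_logits)
    loss_cls = cross_entropy(cls_logits, label)            # [CLS] head vs ground truth
    loss_dist = cross_entropy(dist_logits, teacher_label)  # [DIST] head vs teacher
    return 0.5 * loss_cls + 0.5 * loss_dist


def deit_predict(cls_logits, dist_logits):
    # Inference: average the two heads' logits, then pick the top class.
    return jnp.argmax((cls_logits + dist_logits) / 2)


cls_logits = jnp.array([2.0, 0.5, -1.0, 0.0])
dist_logits = jnp.array([0.0, 1.0, -1.0, 0.5])
teacher_logits = jnp.array([0.1, 3.0, 0.2, 0.0])  # teacher prefers class 1

loss = deit_hard_distill_loss(cls_logits, dist_logits, label=0, teacher_logits=teacher_logits)
pred = deit_predict(cls_logits, dist_logits)
print(int(pred))  # 0
```

In this example the ground truth (class 0) and the teacher (class 1) disagree. That disagreement is exactly where the [DIST] head receives a different training signal from the [CLS] head.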

Architecture

image (H, W, C)
→ patch embed                     → (num_patches, D)
→ prepend [CLS] and [DIST]        → (num_patches + 2, D)
→ + learned pos embed (length+2)  → (num_patches + 2, D)
→ encoder block × N               → (num_patches + 2, D)
→ LayerNorm                       → (num_patches + 2, D)
→ cls_head(x[0])                  → (num_classes,)
→ dist_head(x[1])                 → (num_classes,)
→ concat                          → (2 · num_classes,)

Two prepended tokens, two extra pos-embedding slots, and two separate classifier heads. The output is the concatenation of both heads’ logits.

Worked walk-through

With image (4, 4, 3), P=2, D=8, num_classes=4:

  1. Patch conv → (2, 2, 8) → reshape (4, 8) — 4 patches.
  2. cls (1, 8); dist (1, 8); seq = concat([cls, dist, patches], axis=0) → (6, 8).
  3. Pos embed (6, 8). Add.
  4. Two encoder blocks → (6, 8).
  5. LayerNorm.
  6. cls_out = x[0]; dist_out = x[1].
  7. cls_logits = cls_head(cls_out); dist_logits = dist_head(dist_out).
  8. out = jnp.concatenate([cls_logits, dist_logits], axis=0) → (8,).
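The walk-through’s shapes can be checked with plain `jax.numpy`. This sketch makes two substitutions that you should keep in mind. First, the patch conv is written as "cut non-overlapping P×P patches, then apply one shared matrix", which is what a conv with kernel = stride = P and VALID padding computes. Second, biases are omitted. The encoder blocks and LayerNorm (steps 4-5) are skipped because they do not change the shape. All weights here are random placeholders.

```python
import jax
import jax.numpy as jnp

H, W, C, P, D, num_classes = 4, 4, 3, 2, 8, 4
keys = jax.random.split(jax.random.PRNGKey(0), 7)
image = jax.random.normal(keys[0], (H, W, C))

# Step 1: patchify, then one shared linear map (conv with kernel = stride = P, no bias).
grid = image.reshape(H // P, P, W // P, P, C).transpose(0, 2, 1, 3, 4)  # (2, 2, P, P, C)
patches = grid.reshape(-1, P * P * C)                                   # (4, 12)
w_patch = jax.random.normal(keys[1], (P * P * C, D))
x = patches @ w_patch                                                   # (4, 8)

# Steps 2-3: prepend [CLS] and [DIST], add a pos embed with num_patches + 2 rows.
cls_tok = jax.random.normal(keys[2], (1, D))
dist_tok = jax.random.normal(keys[3], (1, D))
x = jnp.concatenate([cls_tok, dist_tok, x], axis=0)                     # (6, 8)
x = x + jax.random.normal(keys[4], (x.shape[0], D))

# Steps 4-5 (encoder blocks + LayerNorm) keep (6, 8) and are omitted here.

# Steps 6-8: read out rows 0 and 1, apply two separate heads, concatenate.
w_cls = jax.random.normal(keys[5], (D, num_classes))
w_dist = jax.random.normal(keys[6], (D, num_classes))
out = jnp.concatenate([x[0] @ w_cls, x[1] @ w_dist], axis=0)
print(out.shape)  # (8,)
```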

Why two tokens

The [CLS] output is trained to match the ground-truth label, while the [DIST] output is trained to match the teacher. Because the targets differ, the two tokens learn different embeddings. The DeiT paper reports that they start out nearly orthogonal and grow more similar through the network without becoming identical. Fusing the two heads at inference therefore behaves like a small ensemble inside one forward pass.

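Hence “data-efficient”: the teacher supplies an extra training signal, and as a ConvNet it also passes on some of its inductive bias. As a result, the transformer reaches strong accuracy without JFT-scale pretraining data.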

Heads with explicit names

Both heads are nn.Dense(num_classes), but they must have DIFFERENT parameters. Passing name= gives each head its own clearly labeled scope in the Flax parameter tree:

cls_head = nn.Dense(num_classes, name="cls_head")
dist_head = nn.Dense(num_classes, name="dist_head")

Without explicit names, Flax would still scope the two heads separately, because each nn.Dense(...) call inside @nn.compact gets a fresh, auto-numbered scope such as Dense_0 and Dense_1. Explicit names make the intent obvious and the parameter tree easier to read.

Common pitfalls

  • Wrong concat order: the sequence must be [cls, dist, patches] so that the readout indices x[0] (CLS) and x[1] (DIST) pick out the right tokens.
  • Pos embed length: num_patches + 2, not +1 (CLS) or +0.
  • One shared head: there must be TWO Dense instances; sharing defeats the purpose.
  • Concat axis on output: a 1-D concat (axis=0) joins two (num_classes,) vectors into (2 * num_classes,). Passing axis=1 to 1-D vectors raises an out-of-bounds axis error, and jnp.stack would give (2, num_classes) instead.
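These pitfalls can be turned into cheap assertions. In this sketch, each token is tagged with a recognizable constant so the readout order is visible. It then checks the sequence length that the pos embed must match, the output shape, and the wrong-axis failure. The values are illustrative placeholders, not real model outputs.

```python
import jax.numpy as jnp

num_patches, D, num_classes = 4, 8, 4

# Tag tokens with constants so the readout order is easy to verify.
cls_tok = jnp.full((1, D), 1.0)
dist_tok = jnp.full((1, D), 2.0)
patches = jnp.zeros((num_patches, D))

seq = jnp.concatenate([cls_tok, dist_tok, patches], axis=0)
assert seq.shape == (num_patches + 2, D)        # pos embed needs exactly this many rows
assert bool((seq[0] == 1.0).all())              # x[0] is [CLS]
assert bool((seq[1] == 2.0).all())              # x[1] is [DIST]

cls_logits = jnp.ones((num_classes,))
dist_logits = jnp.zeros((num_classes,))
out = jnp.concatenate([cls_logits, dist_logits], axis=0)
assert out.shape == (2 * num_classes,)

# 1-D arrays have no axis 1, so this call raises instead of returning a result.
try:
    jnp.concatenate([cls_logits, dist_logits], axis=1)
    wrong_axis_failed = False
except (ValueError, IndexError):
    wrong_axis_failed = True
print(wrong_axis_failed)  # True
```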

Problem

Implement deit_forward(seed, image, patch_size, d_model, num_heads, d_ff, num_layers, num_classes):

  1. Patch embed + reshape to (num_patches, D).
  2. Prepend two learned tokens: [CLS] (1, D) and [DIST] (1, D).
  3. Pos embed length num_patches + 2.
  4. N encoder blocks + final LayerNorm.
  5. cls_logits = cls_head(x[0]); dist_logits = dist_head(x[1]). Use TWO separate nn.Dense(num_classes) modules.
  6. Return concat([cls_logits, dist_logits], axis=0).

Inputs: same shape as pos 46.

Output: 1-D, length 2 * num_classes.

