DeiT — Data-Efficient Image Transformer
Why this matters
The original ViT (pos 46) needed enormous datasets (JFT-300M) to beat ConvNets. DeiT (Touvron et al., 2021) closed that gap on plain ImageNet using a clever trick: a distillation token.
DeiT’s contribution: prepend not one but TWO learned tokens to the
patch sequence — [CLS] AND [DIST]. The student model is trained
with two losses:
- Standard cross-entropy from the [CLS] head against the ground truth label.
- Distillation loss from the [DIST] head against a teacher model’s predictions.
The two heads share the encoder; only the heads themselves diverge. At inference time, you combine the two heads’ predictions (the paper sums their softmax outputs; averaging the logits is a common alternative) or just use one. The trick made ViT data-efficient enough to train on ImageNet from scratch.
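The two-loss training objective can be sketched as follows. This is a minimal illustration, assuming the hard-label distillation variant with equal weights on both terms. The text above does not say which distillation loss is used, and the names `cross_entropy` and `deit_hard_distill_loss` are hypothetical helpers, not part of any library.

```python
import jax
import jax.numpy as jnp


def cross_entropy(logits, label):
    """Cross-entropy of one logit vector against an integer class label."""
    return -jax.nn.log_softmax(logits)[label]


def deit_hard_distill_loss(student_out, teacher_logits, label, num_classes):
    """Assumed hard-distillation loss for a single example (hypothetical helper)."""
    # student_out is the (2 * num_classes,) concatenation [cls_logits, dist_logits].
    cls_logits = student_out[:num_classes]
    dist_logits = student_out[num_classes:]
    # Hard distillation: the teacher's argmax is treated as a second label.
    teacher_label = jnp.argmax(teacher_logits)
    ce_true = cross_entropy(cls_logits, label)             # [CLS] head vs ground truth
    ce_teacher = cross_entropy(dist_logits, teacher_label)  # [DIST] head vs teacher
    # Equal weighting of the two terms is an assumption of this sketch.
    return 0.5 * ce_true + 0.5 * ce_teacher


# Toy values with num_classes = 4.
student_out = jnp.array([2.0, 0.1, -1.0, 0.3, 0.0, 1.5, 0.2, -0.5])
teacher_logits = jnp.array([0.1, 3.0, 0.2, -1.0])
loss = deit_hard_distill_loss(student_out, teacher_logits, label=0, num_classes=4)
```

A soft variant would replace the second term with a KL divergence between temperature-scaled student and teacher distributions; the slicing of the two heads stays the same.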
Architecture
image (H, W, C)
→ patch embed → (num_patches, D)
→ prepend [CLS] and [DIST] → (num_patches + 2, D)
→ + learned pos embed (length+2) → (num_patches + 2, D)
→ encoder block × N → (num_patches + 2, D)
→ LayerNorm → (num_patches + 2, D)
→ cls_head(x[0]): → (num_classes,)
→ dist_head(x[1]): → (num_classes,)
→ concat → (2 · num_classes,)
Two prepended tokens, two pos-embedding slots, two separate classifier heads, output is the concatenation of both heads’ logits.
Worked walk-through
With image (4, 4, 3), P=2, D=8, num_classes=4:
- Patch conv → (2, 2, 8) → reshape (4, 8): 4 patches.
- cls (1, 8); dist (1, 8); seq = concat([cls, dist, patches], axis=0) → (6, 8).
- Pos embed (6, 8). Add.
- Two encoder blocks → (6, 8).
- LayerNorm.
- cls_out = x[0]; dist_out = x[1].
- cls_logits = cls_head(cls_out); dist_logits = dist_head(dist_out).
- out = jnp.concatenate([cls_logits, dist_logits], axis=0) → (8,).
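The shapes in the walk-through can be checked with plain `jax.numpy`. This is only a shape sketch under stated assumptions: random arrays stand in for learned parameters, the patch conv is written as the equivalent "cut into patches, then project linearly", the encoder blocks are omitted because they preserve shape, and `layer_norm` is a simplified version without learned scale or bias.

```python
import jax
import jax.numpy as jnp

H, W, C, P, D, num_classes = 4, 4, 3, 2, 8, 4
keys = jax.random.split(jax.random.PRNGKey(0), 7)

image = jax.random.normal(keys[0], (H, W, C))

# Cut the image into non-overlapping P x P patches and flatten each one.
patches = image.reshape(H // P, P, W // P, P, C)
patches = patches.transpose(0, 2, 1, 3, 4).reshape(-1, P * P * C)  # (4, 12)

# Linear projection; stands in for a conv with kernel = stride = P.
w_proj = jax.random.normal(keys[1], (P * P * C, D)) * 0.02
tokens = patches @ w_proj                                          # (4, 8)

# Two prepended tokens, in the order the readout indices expect.
cls = jax.random.normal(keys[2], (1, D)) * 0.02
dist = jax.random.normal(keys[3], (1, D)) * 0.02
seq = jnp.concatenate([cls, dist, tokens], axis=0)                 # (6, 8)

# Positional embedding has num_patches + 2 rows.
pos = jax.random.normal(keys[4], (seq.shape[0], D)) * 0.02
x = seq + pos                                                      # (6, 8)


def layer_norm(x, eps=1e-6):
    """Simplified LayerNorm over the feature axis (no learned scale/bias)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


# Encoder blocks would go here; they map (6, 8) to (6, 8).
x = layer_norm(x)

# Two separate weight matrices, one per head.
w_cls = jax.random.normal(keys[5], (D, num_classes)) * 0.02
w_dist = jax.random.normal(keys[6], (D, num_classes)) * 0.02
cls_logits = x[0] @ w_cls                                          # (4,)
dist_logits = x[1] @ w_dist                                        # (4,)
out = jnp.concatenate([cls_logits, dist_logits], axis=0)           # (8,)
```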
Why two tokens
The CLS token’s attention pattern at training time will optimize for matching the ground truth. The DIST token’s attention pattern will optimize for matching the teacher’s distribution. Empirically these patterns differ — the DIST head learns features the CLS head doesn’t, and vice versa. Combining them ensembles two learners inside one forward pass.
Hence “data-efficient”: the teacher acts as a regulariser/source of extra signal, so ImageNet alone is enough and no pre-training on a much larger labelled set such as JFT-300M is needed.
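Combining the two heads at inference can be sketched like this, assuming the forward pass returns the `(2 * num_classes,)` concatenation described in this problem. `fuse_heads` and its `mode` argument are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp


def fuse_heads(out, num_classes, mode="avg_logits"):
    """Split the concatenated output and fuse the two heads (hypothetical helper)."""
    cls_logits = out[:num_classes]
    dist_logits = out[num_classes:]
    if mode == "avg_logits":
        # Average the raw logits of the two heads.
        return (cls_logits + dist_logits) / 2.0
    # Otherwise: late fusion by summing the two softmax distributions.
    return jax.nn.softmax(cls_logits) + jax.nn.softmax(dist_logits)


# Toy output with num_classes = 4: first half is [CLS], second half is [DIST].
out = jnp.array([2.0, 0.1, -1.0, 0.3, 0.0, 1.5, 0.2, -0.5])
avg = fuse_heads(out, num_classes=4)
prob_sum = fuse_heads(out, num_classes=4, mode="sum_softmax")
pred = int(jnp.argmax(avg))
```

Both modes return a `(num_classes,)` vector, so the final `argmax` is the same kind of readout as with a single head.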
Heads with explicit names
Both heads are Dense(num_classes). To get DIFFERENT params
(different scopes in Flax), pass name= so they’re tracked
separately:
cls_head = nn.Dense(num_classes, name="cls_head")
dist_head = nn.Dense(num_classes, name="dist_head")
Without explicit names Flax would still scope them differently
(each nn.Dense(...) call inside @nn.compact gets a fresh scope),
but explicit names make the intent obvious.
Common pitfalls
- Wrong concat order: the order [cls, dist, patches] MUST be consistent with the readout indices x[0] (CLS) and x[1] (DIST).
- Pos embed length: num_patches + 2, not +1 (CLS) or +0.
- One shared head: there must be TWO Dense instances; sharing defeats the purpose.
- Concat axis on output: 1-D concat (axis=0) joins two (num_classes,) vectors into (2 * num_classes,). Concatenating along the wrong axis gives a shape error.
Problem
Implement deit_forward(seed, image, patch_size, d_model, num_heads, d_ff, num_layers, num_classes):
- Patch embed + reshape to (num_patches, D).
- Prepend two learned tokens: [CLS] (1, D) and [DIST] (1, D).
- Pos embed length num_patches + 2.
- N encoder blocks + final LayerNorm.
- cls_logits = cls_head(x[0]); dist_logits = dist_head(x[1]). Use TWO separate nn.Dense(num_classes) modules.
- Return concat([cls_logits, dist_logits], axis=0).
Inputs: same shape as pos 46.
Output: 1-D, length 2 * num_classes.