
Vision Transformer (Mean-Pool Variant)

Why this matters

The Vision Transformer (Dosovitskiy et al., 2021) showed that Transformers, originally designed for sequences, can match or beat strong convolutional networks on image classification when pretrained on enough data. The trick is to recast the image as a sequence: cut it into patches, embed each patch, treat the result as a (num_patches, d_model) sequence of tokens, and feed it to a stock Transformer encoder.

ViT became the foundation for CLIP, DINO, MAE, SAM, and most modern vision models. Once you have patch embeddings, the rest is mechanically the same as a BERT-style Transformer encoder.

Architecture

image (H, W, C)
→ patch embed (strided conv)  → (num_patches, D)
→ + learned pos embed         → (num_patches, D)
→ encoder block × N           → (num_patches, D)
→ LayerNorm                   → (num_patches, D)
→ mean-pool over patches      → (D,)
→ Dense(num_classes)          → (num_classes,)

Two halves:

  • Patches → tokens → encoder: same idea as BERT, but with strided-conv patches as the “embedding” instead of token IDs.
  • Encoder output → classifier: pool to a single vector, then apply a linear head to get class logits.

This problem uses mean-pooling over the patch sequence to get one vector per image. The next problem (pos 46) replaces mean-pooling with a [CLS] token, which is the design used in the original ViT.

Patch embedding (refresher from pos 38)

# inside an @nn.compact __call__, with image: (H, W, C)
feat = nn.Conv(features=D, kernel_size=(P, P), strides=(P, P), padding="VALID")(image)  # (H/P, W/P, D)
tokens = feat.reshape(num_patches, D)        # (H/P * W/P, D)

Strided conv with kernel = stride = patch_size is the standard way to express a “non-overlapping linear projection of each patch.”

Worked walk-through

With image (4, 4, 3), P=2, D=8, num_layers=2, num_classes=4:

  1. feat = conv(image) → (2, 2, 8). Reshape → (4, 8) (4 patches).
  2. pos = pos_embed[:4] → (4, 8). x = tokens + pos.
  3. Two ViT encoder blocks (LN + MHA + FFN, Pre-LN, no causal mask).
  4. x = LayerNorm(x).
  5. pooled = jnp.mean(x, axis=0) → (8,).
  6. logits = Dense(4)(pooled) → (4,).

Common pitfalls

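  • Forgetting the position embedding: self-attention is permutation-equivariant and mean-pooling is permutation-invariant, so without positions, shuffling the patches leaves the logits unchanged and the model can’t tell “top-left” from “bottom-right.” The position embedding has one row per patch: shape (num_patches, D).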
  • Pooling axis: jnp.mean(x, axis=0) reduces over patches (axis 0) and returns a (D,) vector. axis=-1 would reduce over D instead and return one scalar per patch, which is wrong.
  • Patch size not dividing the image: with padding='VALID', trailing rows and columns that don’t fill a whole patch are silently dropped. Resize or crop the image to a multiple of patch_size first.
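  • Stride ≠ patch size: a stride smaller than patch_size makes patches overlap, and a larger one skips pixels; either way it is no longer the standard ViT patch embedding.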

Problem

Implement vit_forward(seed, image, patch_size, d_model, num_heads, d_ff, num_layers, num_classes):

  1. Patch-embed via strided conv. Reshape to (num_patches, D).
  2. Add a learned position embedding of shape (num_patches, D).
  3. N ViT encoder blocks (Pre-LN MHA + FFN, no causal mask).
  4. Final LayerNorm.
  5. Mean-pool over patches → (D,).
  6. Dense(num_classes)(pooled) → (num_classes,).
  7. Return flattened.

Inputs:

  • seed: int.
  • image: 3-D (H, W, C). H and W are divisible by patch_size.
  • All other args: ints.

Output: 1-D, length num_classes.

