Vision Transformer (Mean-Pool Variant)
Why this matters
The Vision Transformer (Dosovitskiy et al., 2021) showed that
Transformers, designed for sequences, can match or beat strong CNNs on
image classification when pre-trained on enough data. The trick is reshaping the image as a
sequence: cut it into patches, embed each patch, treat the result
as (num_patches, d_model) tokens, and feed it to a stock
Transformer encoder.
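A minimal sketch of the "image as a sequence" step in jax.numpy, assuming non-overlapping P × P patches taken in row-major order. `patchify` is a hypothetical helper name, not part of the problem's API. Each row of the result is one flattened patch; a linear projection applied to these rows is what the patch embedding computes (up to how its weights are flattened).

```python
import jax.numpy as jnp


def patchify(image, P):
    """Split an (H, W, C) image into non-overlapping P x P patches.

    Returns a (num_patches, P * P * C) matrix, one flattened patch per row,
    with patches ordered row by row across the image.
    """
    H, W, C = image.shape
    assert H % P == 0 and W % P == 0, "patch size must divide H and W"
    x = image.reshape(H // P, P, W // P, P, C)   # (grid_h, P, grid_w, P, C)
    x = x.transpose(0, 2, 1, 3, 4)               # (grid_h, grid_w, P, P, C)
    return x.reshape((H // P) * (W // P), P * P * C)


image = jnp.arange(4 * 4 * 3).reshape(4, 4, 3)
tokens = patchify(image, 2)
print(tokens.shape)  # (4, 12): 4 patches, each 2 * 2 * 3 values
```

Row 0 is the top-left 2 × 2 patch, row 1 the top-right one, and so on.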
ViT became the foundation for CLIP, DINO, MAE, SAM, and many other modern vision models. Once you have patch embeddings, the recipe is mechanically identical to a BERT-style encoder.
Architecture
image (H, W, C)
  → patch embed (strided conv)   → (num_patches, D)
  → + learned pos embed          → (num_patches, D)
  → encoder block × N            → (num_patches, D)
  → LayerNorm                    → (num_patches, D)
  → mean-pool over patches       → (D,)
  → Dense(num_classes)           → (num_classes,)
Two halves:
- Patches → tokens → encoder: same idea as BERT, just with strided-conv patches as the “embedding” instead of token IDs.
- Encoder output → classifier: pool to a single vector, then a linear head to class logits.
This problem uses mean-pooling over the patch sequence to get
one vector per image. The next problem (pos 46) replaces mean-pool
with a [CLS] token, the original ViT design.
Patch embedding (refresher from pos 38)
feat = nn.Conv(features=D, kernel_size=(P, P), strides=(P, P), padding="VALID")(image)  # (H/P, W/P, D)
tokens = feat.reshape(num_patches, D)  # (H/P · W/P, D)
Strided conv with kernel = stride = patch_size is the standard
way to express a “non-overlapping linear projection of each patch.”
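The equivalence can be checked numerically. The sketch below (pure JAX, random inputs, and a hypothetical `patchify` helper) runs a kernel = stride = P convolution with `jax.lax.conv_general_dilated` and compares it with flattening each patch and multiplying by the same kernel reshaped to (P·P·C, D). It assumes the HWIO kernel layout, which is how Flax's `nn.Conv` stores its kernel, so the kernel's (kh, kw, c) flattening order lines up with the patch flattening.

```python
import jax
import jax.numpy as jnp


def patchify(image, P):
    # (H, W, C) -> (num_patches, P * P * C), patches in row-major order.
    H, W, C = image.shape
    x = image.reshape(H // P, P, W // P, P, C).transpose(0, 2, 1, 3, 4)
    return x.reshape((H // P) * (W // P), P * P * C)


H, W, C, P, D = 4, 4, 3, 2, 8
key_img, key_w = jax.random.split(jax.random.PRNGKey(0))
image = jax.random.normal(key_img, (H, W, C))
kernel = jax.random.normal(key_w, (P, P, C, D))  # HWIO layout

# Strided conv: kernel = stride = P, no padding. A batch axis is added, then removed.
feat = jax.lax.conv_general_dilated(
    image[None], kernel, window_strides=(P, P), padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))[0]       # (H/P, W/P, D)
conv_tokens = feat.reshape(-1, D)                         # (num_patches, D)

# The same projection written as one matmul over flattened patches.
matmul_tokens = patchify(image, P) @ kernel.reshape(P * P * C, D)

print(conv_tokens.shape)                                          # (4, 8)
print(bool(jnp.allclose(conv_tokens, matmul_tokens, atol=1e-4)))  # True
```

Because the windows never overlap, every output position sees exactly one patch, which is why a single shared weight matrix reproduces the conv.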
Worked walk-through
With image (4, 4, 3), P=2, D=8, num_layers=2, num_classes=4:
- feat = conv(image) → (2, 2, 8). Reshape → (4, 8), i.e. 4 patches.
- pos = pos_embed[:4] → (4, 8). x = tokens + pos.
- Two ViT encoder blocks (LN + MHA + FFN, Pre-LN, no causal mask).
- x = LayerNorm(x).
- pooled = jnp.mean(x, axis=0) → (8,).
- logits = Dense(4)(pooled) → (4,).
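The walk-through can be sketched end to end in pure JAX. This is an illustration under stated assumptions, not the reference solution: parameters come from a hypothetical `init_params` with N(0, 0.02²) weights and no biases, LayerNorm has no learned scale or shift, the FFN uses GELU, and num_heads = 2, d_ff = 16 are picked for the example because the walk-through leaves them unspecified. Since the initialization differs from whatever the grader uses, only the shapes carry over, not the logit values.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Per-token normalization over the feature axis (no learned scale/shift).
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def init_params(key, P, C, D, d_ff, num_layers, num_patches, num_classes):
    # Hypothetical initialization: N(0, 0.02^2) weights, biases omitted.
    keys = iter(jax.random.split(key, 3 + 4 * num_layers))

    def normal(shape):
        return 0.02 * jax.random.normal(next(keys), shape)

    return {
        "patch_w": normal((P, P, C, D)),   # HWIO conv kernel
        "pos": normal((num_patches, D)),   # learned position embedding
        "head_w": normal((D, num_classes)),
        "blocks": [{"qkv": normal((D, 3 * D)), "out": normal((D, D)),
                    "ff1": normal((D, d_ff)), "ff2": normal((d_ff, D))}
                   for _ in range(num_layers)],
    }


def encoder_block(x, p, num_heads):
    # Pre-LN block: x + MHA(LN(x)), then x + FFN(LN(x)); no causal mask.
    n, d = x.shape
    hd = d // num_heads

    def split_heads(t):  # (n, d) -> (num_heads, n, hd)
        return t.reshape(n, num_heads, hd).transpose(1, 0, 2)

    q, k, v = map(split_heads, jnp.split(layer_norm(x) @ p["qkv"], 3, axis=-1))
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(hd)       # (num_heads, n, n)
    heads = jax.nn.softmax(scores, axis=-1) @ v             # (num_heads, n, hd)
    x = x + heads.transpose(1, 0, 2).reshape(n, d) @ p["out"]
    return x + jax.nn.gelu(layer_norm(x) @ p["ff1"]) @ p["ff2"]


def vit_forward(seed, image, patch_size, d_model, num_heads, d_ff, num_layers, num_classes):
    H, W, C = image.shape
    P = patch_size
    num_patches = (H // P) * (W // P)
    params = init_params(jax.random.PRNGKey(seed), P, C, d_model, d_ff,
                         num_layers, num_patches, num_classes)
    feat = jax.lax.conv_general_dilated(                    # strided-conv patch embed
        image[None].astype(jnp.float32), params["patch_w"],
        window_strides=(P, P), padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))[0]      # (H/P, W/P, D)
    x = feat.reshape(num_patches, d_model) + params["pos"]  # (num_patches, D)
    for block in params["blocks"]:
        x = encoder_block(x, block, num_heads)              # shape preserved
    pooled = jnp.mean(layer_norm(x), axis=0)                # final LN, mean over patches -> (D,)
    logits = pooled @ params["head_w"]                      # (num_classes,)
    return logits.reshape(-1)


image = jnp.arange(4 * 4 * 3, dtype=jnp.float32).reshape(4, 4, 3) / 48.0
logits = vit_forward(0, image, patch_size=2, d_model=8, num_heads=2,
                     d_ff=16, num_layers=2, num_classes=4)
print(logits.shape)  # (4,)
```

The sketch assumes d_model is divisible by num_heads; a Flax version would replace the hand-written projections with `nn.Conv`, `nn.Dense`, and `nn.LayerNorm` modules.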
Common pitfalls
- Forgetting the position embedding: with no positions, self-attention treats the patches as an unordered set, so the model can't tell “top-left” from “bottom-right.” The position embedding has length num_patches.
- Pooling axis: jnp.mean(x, axis=0) reduces over patches (axis 0); axis=-1 reduces over D, which is wrong.
- Patch size not dividing the image: with padding='VALID', the trailing partial patch is silently dropped. Resize the image first.
- Stride ≠ patch size: a stride smaller than the patch makes patches overlap (as in PVTv2's overlapping patch embedding), and a larger stride skips pixels; either way it is no longer the ViT patch embedding.
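Three of these pitfalls are easy to check numerically. The sketch below is pure JAX and uses a single-head residual self-attention layer as a simplified stand-in for a full encoder block; all names are illustrative. It shows that without position embeddings the mean-pooled vector does not change when the patches are reordered, while adding a position embedding breaks that symmetry. It also prints the shapes from the two pooling axes and from a VALID strided conv on a 5 × 5 image with P = 2.

```python
import jax
import jax.numpy as jnp


def self_attention(x, wq, wk, wv):
    # Single-head, unmasked self-attention over (N, D) tokens.
    q, k, v = x @ wq, x @ wk, x @ wv
    weights = jax.nn.softmax(q @ k.T / jnp.sqrt(x.shape[-1]), axis=-1)
    return weights @ v


def pooled_features(tokens, pos, w):
    x = tokens + pos
    x = x + self_attention(x, *w)   # residual connection, as in an encoder block
    return jnp.mean(x, axis=0)      # mean over patches (axis 0) -> (D,)


N, D = 4, 8
keys = jax.random.split(jax.random.PRNGKey(0), 5)
tokens = jax.random.normal(keys[0], (N, D))
w = [jax.random.normal(k, (D, D)) / jnp.sqrt(D) for k in keys[1:4]]
pos = jax.random.normal(keys[4], (N, D))
no_pos = jnp.zeros((N, D))
perm = jnp.array([3, 2, 1, 0])      # reverse patch order: top-left <-> bottom-right

# Without positions, reordering the patches leaves the pooled vector unchanged.
same_without_pos = jnp.allclose(pooled_features(tokens, no_pos, w),
                                pooled_features(tokens[perm], no_pos, w), atol=1e-5)
# With a position embedding, the same reordering changes the result.
same_with_pos = jnp.allclose(pooled_features(tokens, pos, w),
                             pooled_features(tokens[perm], pos, w), atol=1e-5)
print(bool(same_without_pos), bool(same_with_pos))

# Pooling axis: axis=0 averages over patches; axis=-1 averages over features.
print(jnp.mean(tokens, axis=0).shape, jnp.mean(tokens, axis=-1).shape)  # (8,) (4,)

# VALID padding drops the trailing partial patch: 5x5 input, P=2 -> 2x2 grid.
out = jax.lax.conv_general_dilated(
    jnp.ones((1, 5, 5, 3)), jnp.ones((2, 2, 3, 8)), window_strides=(2, 2),
    padding="VALID", dimension_numbers=("NHWC", "HWIO", "NHWC"))
print(out.shape)  # (1, 2, 2, 8)
```

In the last check, row 4 and column 4 of the 5 × 5 input never reach the model, with no error raised.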
Problem
Implement vit_forward(seed, image, patch_size, d_model, num_heads, d_ff, num_layers, num_classes):
- Patch-embed via strided conv. Reshape to (num_patches, D).
- Add a learned position embedding of shape (num_patches, D).
- N = num_layers ViT encoder blocks (Pre-LN MHA + FFN, no causal mask).
- Final LayerNorm.
- Mean-pool over patches → (D,).
- Dense(num_classes)(pooled) → (num_classes,).
- Return flattened.
Inputs:
- seed: int.
- image: 3-D, (H, W, C); H and W divisible by patch_size.
- All other args: ints.
Output: 1-D, length num_classes.