NNX Vision Transformer
Why this matters
ViT (Dosovitskiy et al., 2020) showed that you don’t need convolutions to do image recognition — you can chop an image into patches, flatten them into a sequence of “visual tokens,” and feed them to a vanilla Transformer encoder. The whole architecture is just patch embedding + N encoder blocks + classifier head. No inductive biases beyond what the patch grid provides.
The whole “ViT story” is one trick: turn a (H, W, C) image into
a (num_patches, d_model) sequence using a strided convolution.
Once you have a sequence, the rest is the encoder block from pos 41,
stacked. This problem builds the mean-pool variant (no [CLS]
token); pos 46 builds the [CLS]-token version.
Patch embedding via strided Conv
Splitting an image into non-overlapping patches and projecting each
one through a learned linear is mathematically equivalent to a
convolution with kernel size and stride both equal to patch_size,
with VALID padding. nnx ships this as nnx.Conv:
self.patch_embed = nnx.Conv(
    in_features=C,
    out_features=d_model,
    kernel_size=(patch_size, patch_size),
    strides=patch_size,
    padding="VALID",
    rngs=rngs,
)
For an (H, W, C) image with patch_size=P (H and W divisible by P),
the output is (H/P, W/P, d_model): one feature vector per patch in
the grid. nnx.Conv accepts unbatched (H, W, C) input directly: it
adds a batch axis internally and strips it from the output.
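The equivalence between "patchify + shared linear" and a strided convolution can be checked numerically. The sketch below uses plain JAX rather than nnx.Conv: `jax.lax.conv_general_dilated` with NHWC/HWIO layout stands in for the conv layer, a random kernel stands in for learned weights, the bias is omitted, and the toy sizes are arbitrary assumptions. It cuts the image into P x P patches with reshape/transpose, flattens each into a P*P*C vector, and multiplies by the same kernel reshaped to (P*P*C, d_model).

```python
import jax
import jax.numpy as jnp

H, W, C, P, D = 4, 4, 3, 2, 5  # toy sizes; H and W divisible by P
k_img, k_w = jax.random.split(jax.random.PRNGKey(0))
image = jax.random.normal(k_img, (H, W, C))
kernel = jax.random.normal(k_w, (P, P, C, D))  # HWIO kernel, stand-in for learned weights

# Path 1: convolution with kernel size == stride == P and VALID padding.
conv_out = jax.lax.conv_general_dilated(
    image[None],                      # add a batch axis: (1, H, W, C)
    kernel,
    window_strides=(P, P),
    padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)[0]                                  # drop the batch axis: (H/P, W/P, D)

# Path 2: cut out every P x P x C patch, flatten it, apply one shared linear map.
patches = image.reshape(H // P, P, W // P, P, C).transpose(0, 2, 1, 3, 4)
patches = patches.reshape(H // P, W // P, P * P * C)  # (Hp, Wp, P*P*C)
linear_out = patches @ kernel.reshape(P * P * C, D)   # (Hp, Wp, D)

print(conv_out.shape, jnp.allclose(conv_out, linear_out, atol=1e-4))
```

The two paths agree because lax convolutions are cross-correlations (no kernel flip), so each output cell is exactly a dot product between one flattened patch and the flattened kernel.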
Then flatten the patch grid to a sequence:
patches = self.patch_embed(image) # (Hp, Wp, d_model)
Hp, Wp, D = patches.shape
seq = patches.reshape(Hp * Wp, D) # (num_patches, d_model)
Now seq is exactly the “tokens” that an encoder eats.
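The reshape flattens the grid in row-major (raster) order, so token k comes from grid cell (k // Wp, k % Wp), and row k of the position embedding is tied to that cell. A minimal check with jax.numpy, using a grid whose cells store their own (i, j) coordinates in place of patch features (the grid size is an arbitrary assumption):

```python
import jax.numpy as jnp

Hp, Wp = 2, 3
# Each grid cell (i, j) holds the vector [i, j] instead of a d_model feature.
grid = jnp.stack(
    jnp.meshgrid(jnp.arange(Hp), jnp.arange(Wp), indexing="ij"), axis=-1
)                                  # (Hp, Wp, 2)
seq = grid.reshape(Hp * Wp, 2)     # same flattening as patches.reshape(Hp * Wp, D)

for k in range(Hp * Wp):
    assert tuple(seq[k].tolist()) == (k // Wp, k % Wp)
print(seq.tolist())  # [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]
```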
Position embeddings: learned, one per patch
Patches alone don’t encode where they came from in the image. ViT adds a learned position embedding, one row per patch, same shape as the patch sequence:
self.pos_embed = nnx.Param(jnp.zeros((num_patches, d_model)))
seq = seq + self.pos_embed.value
For the test inputs, all images are (4, 4, 3) with patch_size=2,
so num_patches = 4. Compute it from image.shape in the entry
function and pass it to the module.
Mean-pool head (vs [CLS])
For classification, ViT-base used a [CLS] token (pos 46). This
problem uses the simpler mean-pool variant:
pooled = jnp.mean(x, axis=0) # (d_model,)
logits = self.head(pooled) # (num_classes,)
The mean over the sequence axis aggregates the per-patch
representations into a single global descriptor. Empirically this
works almost as well as [CLS]-token pooling for ImageNet, and
is conceptually cleaner (no special token, no +1 to the position
embedding length).
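A shape-level sketch of the mean-pool head in plain JAX. The random `x` stands in for the LayerNorm'd encoder output and `w` for the head weights (bias omitted); both, and the sizes, are assumptions for illustration. It also shows that averaging over tokens ignores their order, and what the wrong axis would produce.

```python
import jax
import jax.numpy as jnp

num_patches, d_model, num_classes = 4, 8, 10
k_x, k_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_x, (num_patches, d_model))  # stand-in for encoder output
w = jax.random.normal(k_w, (d_model, num_classes))  # stand-in for head weights

pooled = jnp.mean(x, axis=0)   # (d_model,): one global descriptor
logits = pooled @ w            # (num_classes,)
print(pooled.shape, logits.shape)  # (8,) (10,)

# The mean does not depend on token order, and no extra token is needed,
# so pos_embed keeps exactly num_patches rows.
shuffled = x[jnp.array([2, 0, 3, 1])]
print(jnp.allclose(jnp.mean(shuffled, axis=0), pooled, atol=1e-6))  # True

# Wrong axis: averages channels, leaving one number per patch.
print(jnp.mean(x, axis=-1).shape)  # (4,)
```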
The full pipeline
image (H, W, C)
|
v
nnx.Conv(kernel=P, stride=P, valid) -> (H/P, W/P, d_model)
reshape -> (num_patches, d_model)
+ pos_embed -> (num_patches, d_model)
|
v
[EncoderBlock] x num_layers -> (num_patches, d_model)
|
v
nnx.LayerNorm -> (num_patches, d_model)
mean over axis=0 -> (d_model,)
|
v
nnx.Linear(num_classes) -> (num_classes,)
Common pitfalls
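- Wrong padding. Use padding="VALID" so the output is
  (H // P, W // P, d_model) and the token count matches
  num_patches = (H // P) * (W // P). When H or W is not a multiple of P,
  "SAME" zero-pads the input and yields ceil(H/P) x ceil(W/P) patches,
  so the sequence no longer matches pos_embed. (When the sizes divide
  evenly, as with the 4x4 test images and P=2, both modes give the
  same grid, which is why this bug can hide in small tests.)
- kernel_size != strides. Both must equal patch_size to get
  non-overlapping, gap-free patches. A stride smaller than the kernel
  gives overlapping receptive fields (and more tokens); a larger stride
  skips pixels. Either way it is a different architecture.
- Forgetting to flatten the patch grid. (H/P, W/P, D) is not a
  sequence yet; the encoder expects a 2-D (T, D) input. Reshape.
- Position embed length mismatch. pos_embed must have shape
  (num_patches, d_model), not (max_T, d_model) as in the LM models.
  Compute num_patches = (H // P) * (W // P) from image.shape.
- Skipping the final LayerNorm. ViT, like GPT/BERT, applies a final
  LayerNorm before the head.
- Pooling over the wrong axis. jnp.mean(x, axis=0) averages over
  patches; axis=-1 would average over channels. You want axis=0.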
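The padding and stride pitfalls come down to the conv output grid shape. The sketch below measures that shape directly with `jax.lax.conv_general_dilated` on zero tensors; `grid_shape` is a hypothetical helper, and channel counts are arbitrary since only the spatial shape matters.

```python
import jax
import jax.numpy as jnp


def grid_shape(H, W, kernel, stride, padding, C=3, D=1):
    """Spatial output shape of a 2-D conv on an (H, W, C) input (NHWC / HWIO)."""
    x = jnp.zeros((1, H, W, C))
    k = jnp.zeros((kernel, kernel, C, D))
    y = jax.lax.conv_general_dilated(
        x, k, window_strides=(stride, stride), padding=padding,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y.shape[1:3]


print(grid_shape(4, 4, 2, 2, "VALID"))  # (2, 2): 4 patches
print(grid_shape(4, 4, 2, 2, "SAME"))   # (2, 2): same grid, 4 divides evenly by 2
print(grid_shape(5, 5, 2, 2, "VALID"))  # (2, 2): last row and column dropped
print(grid_shape(5, 5, 2, 2, "SAME"))   # (3, 3): zero-padded partial patches
print(grid_shape(4, 4, 2, 1, "VALID"))  # (3, 3): stride < kernel, overlapping windows
```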
Problem
Write vit_forward(seed, image, patch_size, d_model, num_heads, d_ff, num_layers, num_classes):
- Inner unmasked MHA and EncoderBlock (same as pos 41).
- ViT(nnx.Module) with patch_embed (nnx.Conv), pos_embed (nnx.Param of shape (num_patches, d_model)), blocks (nnx.List of EncoderBlocks), ln_f (nnx.LayerNorm), and head (nnx.Linear(d_model, num_classes)).
- Forward: patch_embed -> reshape -> + pos -> blocks -> LN -> mean-pool over axis 0 -> head.
- Cast hyperparameters to int. Compute num_patches = (H // P) * (W // P).
- Return the logits flattened with out.reshape(-1) (they are already 1-D).
Inputs:
- seed: int (passed as float).
- image: 3-D (H, W, C).
- patch_size, d_model, num_heads, d_ff, num_layers, num_classes: ints (passed as floats).

Output: 1-D (num_classes,).