hard primitives

NNX Vision Transformer

Why this matters

ViT (Dosovitskiy et al., 2020) showed that you don’t need convolutions to do image recognition: you can chop an image into patches, flatten them into a sequence of “visual tokens,” and feed them to a vanilla Transformer encoder. The whole architecture is just patch embedding + position embedding + N encoder blocks + classifier head. Beyond the patch grid itself, it builds in almost no image-specific inductive bias.

The whole “ViT story” is one trick: turn a (H, W, C) image into a (num_patches, d_model) sequence using a strided convolution. Once you have a sequence, the rest is the encoder block from pos 41, stacked. This problem builds the mean-pool variant (no [CLS] token); pos 46 builds the [CLS]-token version.

Patch embedding via strided Conv

Splitting an image into non-overlapping patches and projecting each one through a shared learned linear map is mathematically equivalent to a convolution whose kernel size and stride both equal patch_size, with VALID padding. nnx provides this as nnx.Conv:

self.patch_embed = nnx.Conv(
    in_features=C,
    out_features=d_model,
    kernel_size=(patch_size, patch_size),
    strides=patch_size,
    padding="VALID",
    rngs=rngs,
)

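For a (H, W, C) image with patch_size=P, the output is (H/P, W/P, d_model), one feature vector per patch in the grid. (If H or W is not a multiple of P, VALID padding floors the division and drops the leftover pixels.) nnx.Conv accepts unbatched (H, W, C) input directly: it adds a batch dimension for the convolution and removes it again on return.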

Then flatten the patch grid to a sequence:

patches = self.patch_embed(image)        # (Hp, Wp, d_model)
Hp, Wp, D = patches.shape
seq = patches.reshape(Hp * Wp, D)        # (num_patches, d_model)

Now seq is exactly the “tokens” that an encoder eats.
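The reshape flattens the grid in row-major (raster) order, so token k comes from grid cell (k // Wp, k % Wp). A minimal check on a toy integer-valued grid (sizes chosen arbitrarily):

```python
import jax.numpy as jnp

Hp, Wp, D = 2, 3, 4  # toy grid: 2 rows x 3 columns of patches, 4 features each
patches = jnp.arange(Hp * Wp * D).reshape(Hp, Wp, D)
seq = patches.reshape(Hp * Wp, D)  # (num_patches, d_model)

# Token k is the patch at row k // Wp, column k % Wp (row-major order).
for k in range(Hp * Wp):
    assert (seq[k] == patches[k // Wp, k % Wp]).all()

print(seq.shape)  # (6, 4)
```

Since pos_embed row k is added to token k, any fixed ordering would work; what matters is that the same flattening is used every time the model runs.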

Position embeddings: learned, one per patch

Patches alone don’t encode where they came from in the image. ViT adds a learned position embedding, one row per patch, same shape as the patch sequence:

self.pos_embed = nnx.Param(jnp.zeros((num_patches, d_model)))
seq = seq + self.pos_embed.value

For the test inputs, all images are (4, 4, 3) with patch_size=2, so num_patches = 4. Compute it from image.shape in the entry function and pass it to the module.
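Computing num_patches from the image shape and adding the position table might look like this; num_patches_for is a hypothetical helper name, and the zero initialisation simply mirrors the snippet above.

```python
import jax
import jax.numpy as jnp


def num_patches_for(image_shape, patch_size):
    """Hypothetical helper: number of P x P patches in an (H, W, C) image."""
    H, W, _ = image_shape
    return (H // patch_size) * (W // patch_size)


d_model = 8
n = num_patches_for((4, 4, 3), 2)            # (4 // 2) * (4 // 2) = 4
pos_embed = jnp.zeros((n, d_model))          # one learned row per patch
seq = jax.random.normal(jax.random.key(0), (n, d_model))
seq = seq + pos_embed                        # shapes match exactly: (4, 8)

print(n, num_patches_for((224, 224, 3), 16))  # 4 196
```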

Mean-pool head (vs [CLS])

For classification, the original ViT used a [CLS] token (pos 46). This problem uses the simpler mean-pool variant:

pooled = jnp.mean(x, axis=0)             # (d_model,)
logits = self.head(pooled)               # (num_classes,)

The mean over the sequence axis aggregates the per-patch representations into a single global descriptor. In the original paper’s ablation, this global average pooling performs about as well as [CLS]-token pooling once the learning rate is tuned, and it is conceptually cleaner (no special token, no +1 to the position embedding length).
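For concreteness, here is the mean-pool head on plain jax.numpy arrays rather than nnx modules; x, W, and b are random stand-ins (assumptions) for the ln_f output and the head’s learned weights.

```python
import jax
import jax.numpy as jnp

num_patches, d_model, num_classes = 4, 8, 10
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (num_patches, d_model))         # stand-in for ln_f output
W = 0.02 * jax.random.normal(k2, (d_model, num_classes))  # stand-in head weights
b = jnp.zeros((num_classes,))                             # stand-in head bias

pooled = jnp.mean(x, axis=0)  # (d_model,): average over the patch axis
logits = pooled @ W + b       # (num_classes,)

# Mean pooling ignores token order: every patch contributes equally.
perm = jnp.array([2, 0, 3, 1])
assert jnp.allclose(jnp.mean(x[perm], axis=0), pooled, atol=1e-5)

# Averaging the wrong axis silently produces a different shape.
wrong = jnp.mean(x, axis=-1)  # (num_patches,), not (d_model,)
print(pooled.shape, logits.shape, wrong.shape)  # (8,) (10,) (4,)
```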

The full pipeline

image (H, W, C)
  |
  v
nnx.Conv(kernel=P, stride=P, valid)   -> (H/P, W/P, d_model)
  reshape                             -> (num_patches, d_model)
  + pos_embed                         -> (num_patches, d_model)
  |
  v
[EncoderBlock] x num_layers           -> (num_patches, d_model)
  |
  v
nnx.LayerNorm                         -> (num_patches, d_model)
  mean over axis=0                    -> (d_model,)
  |
  v
nnx.Linear(num_classes)               -> (num_classes,)

Common pitfalls

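  • Wrong padding. Use padding="VALID" so the output is (H/P, W/P, d_model). nnx.Conv defaults to "SAME", which gives the same grid when H and W are multiples of P but otherwise zero-pads the image and adds an extra partial patch row or column, so the token count no longer matches (H // P) * (W // P).
  • kernel_size != strides. They must both equal patch_size to get non-overlapping patches that tile the image. A kernel larger than the stride gives overlapping receptive fields; a smaller one skips pixels. Either way it is a different architecture.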
  • Forgetting to flatten the patch grid. (H/P, W/P, D) isn’t a sequence yet; the encoder expects 2-D (T, D). Reshape.
  • Position embed length mismatch. pos_embed shape must be (num_patches, d_model), not (max_T, d_model) like the LM models. Compute num_patches = (H // P) * (W // P) from image.shape.
  • Skipping the final LayerNorm. ViT, like GPT-2 and other pre-LN Transformers, applies a final LayerNorm before the head.
  • Pooling the wrong axis. jnp.mean(x, axis=0) averages over patches; axis=-1 would average over channels. You want axis=0.

Problem

Write vit_forward(seed, image, patch_size, d_model, num_heads, d_ff, num_layers, num_classes):

  1. Inner unmasked MHA and EncoderBlock (same as pos 41).
  2. ViT(nnx.Module) with patch_embed (nnx.Conv), pos_embed (nnx.Param shape (num_patches, d_model)), blocks (nnx.List of EncoderBlocks), ln_f (nnx.LayerNorm), head (nnx.Linear(d_model, num_classes)).
  3. Forward: patch_embed -> reshape -> + pos -> blocks -> LN -> mean-pool axis 0 -> head.
  4. Cast hyperparameters to int. Compute num_patches = (H // P) * (W // P).
  5. Return the logits as a 1-D array via out.reshape(-1) (a no-op here, since they are already 1-D).

Inputs:

  • seed: int (passed as float).
  • image: 3-D (H, W, C).
  • patch_size, d_model, num_heads, d_ff, num_layers, num_classes: ints (passed as floats).

Output: 1-D (num_classes,).
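Because every hyperparameter arrives as a float, the entry function has to cast before using them as shapes, layer counts, or seeds. A sketch of that preamble, with hypothetical example values; range() shows what goes wrong without the cast.

```python
# Hypothetical raw arguments, all floats as described under Inputs.
raw = dict(seed=0.0, patch_size=2.0, d_model=8.0, num_heads=2.0,
           d_ff=16.0, num_layers=2.0, num_classes=5.0)

try:
    range(raw["num_layers"])  # a float is not a valid loop count
except TypeError as err:
    print("without cast:", err)

hp = {name: int(value) for name, value in raw.items()}
print(hp["num_layers"], hp["d_model"])  # 2 8
```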
