
NNX ViT with CLS Token

Why this matters

The previous problem mean-pooled across all patch tokens to get a global descriptor. The original ViT (and BERT before it) uses a different trick: prepend a learnable [CLS] (“classification”) token to the sequence, run the transformer, and read off the first output position. The CLS token is its own learnable row of parameters, and training teaches it to attend to whatever it needs from the patches in order to predict the class.

Why bother? Two reasons:

  1. Inductive bias toward a single summary. The model has one dedicated position whose only job is “describe the whole input.” Mean-pool spreads that responsibility across all tokens.
  2. Compatibility. Most published ViT and BERT checkpoints use the CLS pattern, so loading their pretrained weights generally requires it.

The architecture difference vs pos 45 is three small edits:

  1. Add self.cls_token = nnx.Param(jnp.zeros((1, d_model))) — a single learnable row.
  2. pos_embed is now length num_patches + 1 (one extra row for the CLS slot).
  3. Concatenate CLS to the patch sequence BEFORE adding position embeddings, then read x[0] (the CLS row) at the end and pass through the classifier.
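
The three edits can be traced with plain arrays before any module is involved. This is a minimal sketch that assumes small illustrative sizes and uses a random array as a stand-in for the reshaped patch embeddings; none of the names below come from the reference solution.

```python
import jax
import jax.numpy as jnp

num_patches, d_model = 4, 8                        # illustrative sizes only

# Stand-in for the reshaped patch embeddings, shape (num_patches, d_model).
seq = jax.random.normal(jax.random.PRNGKey(0), (num_patches, d_model))

cls_token = jnp.zeros((1, d_model))                # edit 1: one learnable row
pos_embed = jnp.zeros((num_patches + 1, d_model))  # edit 2: one extra slot

# Edit 3: concatenate first, then add positions, then read row 0.
x = jnp.concatenate([cls_token, seq], axis=0)      # (num_patches + 1, d_model)
x = x + pos_embed
cls_out = x[0]                                     # (d_model,), the CLS readout
mean_out = x[1:].mean(axis=0)                      # (d_model,), the mean-pool readout
```

Both readouts have the same shape, so the classifier head does not change; only the row it is fed does.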

Why does the CLS token work?

Self-attention is permutation-equivariant over its inputs: shuffle the tokens and the outputs shuffle the same way (the position embedding is what breaks the symmetry). When you prepend a learnable token, the encoder learns to use it: across many training steps, the CLS row’s gradient pushes toward representations useful for the classification objective, and the attention layers learn to “deposit” relevant per-patch information at position 0.
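
The symmetry claim can be checked numerically. The sketch below assumes a bare single-head attention with random projection matrices and no position embedding (a simplification, not the MHA from the earlier problems): permuting the input rows permutes the output rows in exactly the same way, so without positions the layer cannot tell token order apart.

```python
import jax
import jax.numpy as jnp

def self_attention(x, wq, wk, wv):
    """Single-head self-attention over the rows of x, with no position information."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v

keys = jax.random.split(jax.random.PRNGKey(0), 4)
n, d = 5, 8                                        # illustrative sizes only
x = jax.random.normal(keys[0], (n, d))
wq, wk, wv = (jax.random.normal(k, (d, d)) / jnp.sqrt(d) for k in keys[1:])

perm = jnp.array([3, 0, 4, 1, 2])
out = self_attention(x, wq, wk, wv)
out_perm = self_attention(x[perm], wq, wk, wv)

# Shuffling the input rows shuffles the output rows the same way.
same_up_to_order = bool(jnp.allclose(out_perm, out[perm], atol=1e-4))
```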

Concretely, in every layer, the CLS row queries every patch and integrates the attention-weighted information into its own representation. By the final layer, it acts as a learned pooling over the whole image.
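
To make the "learned pool" reading concrete, here is a hedged single-head sketch of what the CLS row computes inside one attention layer. The projection matrices are random placeholders, and residual connections, LayerNorm, and multiple heads are left out.

```python
import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.PRNGKey(0), 4)
num_patches, d = 4, 8                              # illustrative sizes only
x = jax.random.normal(keys[0], (num_patches + 1, d))   # row 0 plays the CLS role
wq, wk, wv = (jax.random.normal(k, (d, d)) / jnp.sqrt(d) for k in keys[1:])

q_cls = x[0] @ wq                                  # CLS query, shape (d,)
k_all, v_all = x @ wk, x @ wv                      # keys and values for every token
weights = jax.nn.softmax(k_all @ q_cls / jnp.sqrt(d))  # (num_patches + 1,), sums to 1
cls_update = weights @ v_all                       # attention-weighted sum of value rows

# Fixed uniform weights over the patch rows would reduce this to mean-pooling.
uniform = jnp.concatenate([jnp.zeros(1), jnp.full(num_patches, 1.0 / num_patches)])
mean_pool = uniform @ v_all
```

The only difference between the two readouts is where the weights come from: mean-pooling fixes them, while the CLS row computes them from the data (and also attends to itself at index 0).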

Worked sketch

import jax.numpy as jnp
from flax import nnx

# EncoderBlock is the encoder block from pos 41/45.
class ViTCLS(nnx.Module):
    def __init__(self, in_channels, patch_size, num_patches, d_model,
                 num_heads, d_ff, num_layers, num_classes, rngs):
        self.patch_embed = nnx.Conv(
            in_features=in_channels, out_features=d_model,
            kernel_size=(patch_size, patch_size),
            strides=patch_size, padding="VALID", rngs=rngs,
        )
        self.cls_token = nnx.Param(jnp.zeros((1, d_model)))
        self.pos_embed = nnx.Param(jnp.zeros((num_patches + 1, d_model)))
        self.blocks = nnx.List([
            EncoderBlock(d_model, num_heads, d_ff, rngs=rngs)
            for _ in range(num_layers)
        ])
        self.ln_f = nnx.LayerNorm(d_model, rngs=rngs)
        self.head = nnx.Linear(d_model, num_classes, rngs=rngs)

    def __call__(self, image):
        patches = self.patch_embed(image)        # (Hp, Wp, d_model)
        Hp, Wp, D = patches.shape
        seq = patches.reshape(Hp * Wp, D)        # (num_patches, d_model)
        x = jnp.concatenate([self.cls_token.value, seq], axis=0)
                                                 # (num_patches + 1, d_model)
        x = x + self.pos_embed.value
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        cls_out = x[0]                           # (d_model,)
        return self.head(cls_out)                # (num_classes,)

Six attributes (patch_embed, cls_token, pos_embed, blocks, ln_f, head), of which cls_token is new and pos_embed gains one row vs pos 45; two __call__ changes (concat, read [0] instead of mean).

A subtle point: zero-init vs random-init for CLS

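Initializing cls_token to zeros is fine in practice. There is only one CLS row, so there is no symmetry between duplicate units to break, and the row stops being zero as soon as it receives a gradient or a non-zero position embedding (pos_embed is also zeros here, but real ViTs initialize it randomly, often from a truncated normal). Real implementations commonly draw pos_embed, and sometimes cls_token, from a normal or truncated normal with a small standard deviation such as 0.02, for example 0.02 * jax.random.normal(key, shape) or jax.nn.initializers.truncated_normal(stddev=0.02). For this problem, zeros are fine and reproducible.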
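
For comparison with the zero init used in this problem, a small-standard-deviation random init might look like the sketch below. The sizes, key handling, and the choice of which parameter gets which distribution are illustrative assumptions, not part of the problem.

```python
import jax
import jax.numpy as jnp

num_patches, d_model = 4, 8                        # illustrative sizes only
key_cls, key_pos = jax.random.split(jax.random.PRNGKey(0))

# Plain normal draw scaled to a small standard deviation.
cls_init = 0.02 * jax.random.normal(key_cls, (1, d_model))

# Truncated normal via JAX's initializer factory.
trunc = jax.nn.initializers.truncated_normal(stddev=0.02)
pos_init = trunc(key_pos, (num_patches + 1, d_model), jnp.float32)
```

Either array could then be wrapped in nnx.Param in place of the jnp.zeros call. The outputs would no longer match a zero-initialized reference, which is why the problem sticks with zeros.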

Common pitfalls

  • Forgetting + 1 in pos_embed length. The position embedding must match the post-concatenation sequence length, which is num_patches + 1. Off-by-one is the most common bug here.
  • Concatenating after adding position embeddings. Then the CLS token has no position embedding (or the wrong one) — the position grid is for num_patches + 1 slots, including the CLS row. Concatenate first, then add pos.
  • Reading the wrong position at the end. The CLS token is prepended at index 0 in concat, so the final classifier input is x[0], not x[-1] or jnp.mean(x, axis=0).
  • Storing cls_token as a plain attribute or with the wrong shape. It must be an nnx.Param (trainable) with shape (1, d_model) so that concatenate works.
  • Putting cls_token second instead of first. Convention is first, and pretrained checkpoints assume that layout when their weights are loaded.
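
The off-by-one pitfall is easy to guard against with an explicit shape check. The helper below, add_positions, is hypothetical (it is not part of the problem's required API) and only illustrates the idea.

```python
import jax.numpy as jnp

def add_positions(x, pos_embed):
    """Add position embeddings, refusing any shape mismatch (hypothetical helper)."""
    if pos_embed.shape != x.shape:
        raise ValueError(
            f"pos_embed has shape {pos_embed.shape}, expected {x.shape} "
            "(num_patches + 1 rows, counting the CLS slot)"
        )
    return x + pos_embed

num_patches, d_model = 4, 8                        # illustrative sizes only
seq = jnp.ones((num_patches, d_model))
cls_token = jnp.zeros((1, d_model))
x = jnp.concatenate([cls_token, seq], axis=0)      # CLS first, so it lands at index 0

ok = add_positions(x, jnp.zeros((num_patches + 1, d_model)))

try:
    add_positions(x, jnp.zeros((num_patches, d_model)))   # forgot the + 1
    caught = False
except ValueError:
    caught = True
```

An explicit check is worth having because broadcasting can hide the bug: with num_patches equal to 1, a (1, d_model) pos_embed would silently broadcast across both rows instead of raising.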

Problem

Write vit_cls_forward(seed, image, patch_size, d_model, num_heads, d_ff, num_layers, num_classes):

  1. Same MHA and EncoderBlock as pos 41/45.
  2. ViTCLS(nnx.Module) with patch_embed (nnx.Conv), cls_token (nnx.Param(jnp.zeros((1, d_model)))), pos_embed (nnx.Param(jnp.zeros((num_patches + 1, d_model)))), blocks (nnx.List), ln_f, head.
  3. Forward: patch embed -> reshape -> concatenate [cls_token.value, seq] at axis 0 -> + pos_embed -> blocks -> LN -> read x[0] -> head.
  4. Cast hyperparameters to int. num_patches = (H // P) * (W // P).
  5. Return logits flattened.
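
The casting and num_patches arithmetic in step 4 can be sanity-checked without any learned weights. The sketch below uses a hypothetical patchify helper that only extracts raw pixel patches; the actual problem uses nnx.Conv for a learned projection, but with padding="VALID" and stride equal to the kernel size the patch grid has the same (H // P, W // P) layout.

```python
import jax.numpy as jnp

def patchify(image, patch_size):
    """Split an (H, W, C) image into non-overlapping flattened patches."""
    P = int(patch_size)                        # hyperparameters arrive as floats
    H, W, C = image.shape
    Hp, Wp = H // P, W // P                    # patch grid, as in step 4
    cropped = image[: Hp * P, : Wp * P]        # drop any remainder rows or columns
    grid = cropped.reshape(Hp, P, Wp, P, C).transpose(0, 2, 1, 3, 4)
    return grid.reshape(Hp * Wp, P * P * C)    # (num_patches, P * P * C)

image = jnp.arange(8 * 6 * 3, dtype=jnp.float32).reshape(8, 6, 3)
patches = patchify(image, 2.0)                 # patch_size passed as a float
num_patches = (image.shape[0] // 2) * (image.shape[1] // 2)
```

Here an 8 x 6 image with patch size 2 gives a 4 x 3 grid, so pos_embed would need num_patches + 1 = 13 rows.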

Inputs:

  • seed: int (passed as float).
  • image: 3-D (H, W, C).
  • patch_size, d_model, num_heads, d_ff, num_layers, num_classes: ints (passed as floats).

Output: 1-D (num_classes,).
