NNX ViT with CLS Token
Why this matters
The previous problem mean-pooled across all patch tokens to get a
global descriptor. The original ViT (and BERT before it) uses a
different trick: prepend a learnable [CLS] (“classification”) token
to the sequence, run the transformer, and read off the FIRST output
position. The CLS token is a single trainable row vector; through
training it learns to attend to whatever it needs from the patches
in order to predict the class.
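To make the two readouts concrete, here is a minimal sketch assuming a random array stands in for the encoder output. The names x, mean_desc and cls_desc are illustrative and not part of the problem.

```python
import jax
import jax.numpy as jnp

d_model, num_patches = 8, 4

# Stand-in for the encoder output: row 0 plays the CLS slot,
# rows 1..num_patches play the patch tokens.
x = jax.random.normal(jax.random.PRNGKey(0), (num_patches + 1, d_model))

# Mean-pool readout: average over the patch rows.
mean_desc = jnp.mean(x[1:], axis=0)   # (d_model,)

# CLS readout: take the first output position as-is.
cls_desc = x[0]                       # (d_model,)
```

Both readouts give a (d_model,) vector for the classifier head; they differ only in which rows feed it.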
Why bother? Two reasons:
- Inductive bias toward a single summary. The model has one dedicated position whose only job is “describe the whole input.” Mean-pooling spreads that responsibility across all tokens.
- Compatibility. Almost every published ViT and BERT checkpoint uses the CLS pattern, so loading pretrained weights pretty much requires it.
The architecture difference vs pos 45 is three small edits:
- Add self.cls_token = nnx.Param(jnp.zeros((1, d_model))), a single learnable row.
- pos_embed is now length num_patches + 1 (one extra row for the CLS slot).
- Concatenate CLS to the patch sequence BEFORE adding position embeddings, then read x[0] (the CLS row) at the end and pass it through the classifier.
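The three edits can be traced as shapes with plain arrays. This is a sketch with toy sizes; a constant array stands in for the patch sequence, and the encoder blocks are skipped.

```python
import jax.numpy as jnp

num_patches, d_model = 4, 8

cls_token = jnp.zeros((1, d_model))                 # edit 1: one extra row
pos_embed = jnp.zeros((num_patches + 1, d_model))   # edit 2: length num_patches + 1
seq = jnp.ones((num_patches, d_model))              # stand-in for embedded patches

# Edit 3: concatenate first, then add position embeddings.
x = jnp.concatenate([cls_token, seq], axis=0)       # (num_patches + 1, d_model)
x = x + pos_embed

cls_out = x[0]                                      # (d_model,), the CLS row
```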
Why does the CLS token work?
Self-attention is permutation-equivariant over its inputs: shuffle the input tokens and the outputs are shuffled the same way (the position embedding is what breaks this symmetry). When you prepend a learnable token, the encoder learns to USE it: across many training steps, the CLS row’s gradient pushes toward representations useful for the classification objective, and the attention layers learn to “deposit” relevant per-patch information at position 0.
Concretely, on every layer, the CLS row queries every patch and integrates the attention-weighted info into its own representation. By the final layer, it’s a learned-pool over the whole image.
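A minimal single-head sketch of what the CLS row computes in one attention layer. It assumes random projection matrices and leaves out position embeddings, multiple heads, residuals and LayerNorm. The helper cls_attention is hypothetical. It also checks that, without position embeddings, shuffling the patch rows leaves the CLS output unchanged.

```python
import jax
import jax.numpy as jnp

def cls_attention(x, wq, wk, wv):
    """Single-head attention output for row 0 (the CLS row) only."""
    q = x[0] @ wq                               # (d,) query from the CLS row
    k = x @ wk                                  # (n, d) keys from every row
    v = x @ wv                                  # (n, d) values from every row
    scores = k @ q / jnp.sqrt(q.shape[-1])      # (n,) one score per token
    weights = jax.nn.softmax(scores)            # attention weights, sum to 1
    return weights @ v, weights                 # weighted sum of values

d = 8
k_x, k_q, k_k, k_v, k_perm = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(k_x, (5, d))              # row 0 = CLS, rows 1..4 = patches
wq, wk, wv = (jax.random.normal(k, (d, d)) / jnp.sqrt(d) for k in (k_q, k_k, k_v))

out, weights = cls_attention(x, wq, wk, wv)

# Shuffle only the patch rows; the CLS row stays at index 0.
perm = jax.random.permutation(k_perm, 4) + 1
x_shuffled = jnp.concatenate([x[:1], x[perm]], axis=0)
out_shuffled, _ = cls_attention(x_shuffled, wq, wk, wv)
```

Here out and out_shuffled agree up to floating-point error, so this layer alone cannot tell the CLS row where a patch came from. That information has to come from pos_embed.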
Worked sketch
import jax.numpy as jnp
from flax import nnx

# EncoderBlock is the same block as in pos 41/45 (not redefined here).
class ViTCLS(nnx.Module):
    def __init__(self, in_channels, patch_size, num_patches, d_model,
                 num_heads, d_ff, num_layers, num_classes, rngs):
        # Non-overlapping patches: kernel and stride both equal patch_size.
        self.patch_embed = nnx.Conv(
            in_features=in_channels, out_features=d_model,
            kernel_size=(patch_size, patch_size),
            strides=patch_size, padding="VALID", rngs=rngs,
        )
        self.cls_token = nnx.Param(jnp.zeros((1, d_model)))
        self.pos_embed = nnx.Param(jnp.zeros((num_patches + 1, d_model)))
        self.blocks = nnx.List([
            EncoderBlock(d_model, num_heads, d_ff, rngs=rngs)
            for _ in range(num_layers)
        ])
        self.ln_f = nnx.LayerNorm(d_model, rngs=rngs)
        self.head = nnx.Linear(d_model, num_classes, rngs=rngs)

    def __call__(self, image):
        patches = self.patch_embed(image)          # (Hp, Wp, d_model)
        Hp, Wp, D = patches.shape
        seq = patches.reshape(Hp * Wp, D)          # (num_patches, d_model)
        # Prepend CLS, then add position embeddings.
        x = jnp.concatenate([self.cls_token.value, seq], axis=0)
                                                   # (num_patches + 1, d_model)
        x = x + self.pos_embed.value
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        cls_out = x[0]                             # (d_model,)
        return self.head(cls_out)                  # (num_classes,)
Six attributes in total (patch_embed, cls_token, pos_embed, blocks,
ln_f, head); relative to pos 45, cls_token is new and pos_embed has
one extra row. Two __call__ changes (concat, read [0] instead of mean).
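The patch-embedding step can be checked in isolation. This is a sketch, not the nnx.Conv layer itself: it uses jax.lax.conv_general_dilated with a random kernel and no bias. It compares a convolution with kernel and stride both equal to the patch size against cutting the image into patches and applying one matrix multiply. The sizes H, W, C, P, D are arbitrary toy values.

```python
import jax
import jax.numpy as jnp

H, W, C, P, D = 8, 8, 3, 4, 16
k_img, k_w = jax.random.split(jax.random.PRNGKey(0))
image = jax.random.normal(k_img, (H, W, C))
kernel = 0.1 * jax.random.normal(k_w, (P, P, C, D))    # HWIO layout, no bias

# Route 1: strided convolution, kernel size == stride == patch size.
conv_out = jax.lax.conv_general_dilated(
    image[None],                                       # add batch dim: (1, H, W, C)
    kernel,
    window_strides=(P, P),
    padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)[0]                                                   # (Hp, Wp, D)
Hp, Wp, _ = conv_out.shape
seq_conv = conv_out.reshape(Hp * Wp, D)                # (num_patches, D)

# Route 2: cut into P x P patches, flatten each, apply one linear map.
patches = image.reshape(Hp, P, Wp, P, C).transpose(0, 2, 1, 3, 4)   # (Hp, Wp, P, P, C)
seq_lin = patches.reshape(Hp * Wp, P * P * C) @ kernel.reshape(P * P * C, D)
```

Both routes give the same (num_patches, D) sequence in row-major patch order, which is the order the reshape in __call__ relies on.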
A subtle point: zero-init vs random-init for CLS
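Initializing cls_token to zeros is fine in practice: the residual
path still carries a gradient to it, and after the first attention
layer the CLS row typically already holds a mix of patch information,
so it does not stay at zero. Note that pos_embed is also zeros here,
so it only starts to distinguish positions once training updates it.
Real implementations typically draw pos_embed (and sometimes
cls_token) from a normal or truncated normal distribution with a
small standard deviation such as 0.02, for example
0.02 * jax.random.normal(key, shape).
For this problem, zeros are fine and reproducible.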
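For comparison, a random initialization could look like the sketch below. It assumes a standard deviation of 0.02 and uses jax.nn.initializers.truncated_normal. The sizes and key handling are illustrative, and the problem itself still expects zeros.

```python
import jax
import jax.numpy as jnp

num_patches, d_model = 16, 64
k_cls, k_pos = jax.random.split(jax.random.PRNGKey(0))

# Truncated-normal initializer with a small standard deviation (assumed 0.02).
init = jax.nn.initializers.truncated_normal(stddev=0.02)
cls_init = init(k_cls, (1, d_model), jnp.float32)
pos_init = init(k_pos, (num_patches + 1, d_model), jnp.float32)

# Plain-normal alternative: scale standard normal samples by 0.02.
cls_init_plain = 0.02 * jax.random.normal(k_cls, (1, d_model))
```

In the module, these arrays would be wrapped in nnx.Param in place of the jnp.zeros(...) values.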
Common pitfalls
- Forgetting + 1 in pos_embed length. The position embedding must match the post-concatenation sequence length, which is num_patches + 1. Off-by-one is the most common bug here.
- Concatenating after adding position embeddings. Then the CLS token has no position embedding (or the wrong one): the position grid is for num_patches + 1 slots, including the CLS row. Concatenate first, then add pos.
- Reading the wrong position at the end. The CLS token is prepended at index 0 in concat, so the final classifier input is x[0], not x[-1] or jnp.mean(x, axis=0).
- Storing cls_token as plain attribute or wrong shape. It must be nnx.Param (trainable) and shape (1, d_model) so concatenate works.
- Putting cls_token second instead of first. Convention is first; the matching weights are loaded that way.
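Two of these pitfalls fail loudly as shape errors, which the following sketch reproduces with plain arrays. The flags mismatch_detected and rank_error are illustrative, and the exception type is caught broadly because it may vary between JAX versions.

```python
import jax.numpy as jnp

num_patches, d_model = 4, 8
cls_token = jnp.zeros((1, d_model))
seq = jnp.ones((num_patches, d_model))
x = jnp.concatenate([cls_token, seq], axis=0)        # (num_patches + 1, d_model)

# Pitfall: pos_embed without the + 1 does not broadcast against x.
bad_pos_embed = jnp.zeros((num_patches, d_model))
try:
    x + bad_pos_embed
    mismatch_detected = False
except (TypeError, ValueError):
    mismatch_detected = True

# Pitfall: a 1-D cls_token of shape (d_model,) cannot be concatenated
# with the 2-D patch sequence.
try:
    jnp.concatenate([jnp.zeros((d_model,)), seq], axis=0)
    rank_error = False
except (TypeError, ValueError):
    rank_error = True
```

The other pitfalls (wrong concatenation order, reading x[-1]) produce valid shapes and therefore no error, so they need a value check rather than a shape check.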
Problem
Write vit_cls_forward(seed, image, patch_size, d_model, num_heads, d_ff, num_layers, num_classes):
- Same MHA and EncoderBlock as pos 41/45.
- ViTCLS(nnx.Module) with patch_embed (nnx.Conv), cls_token (nnx.Param(jnp.zeros((1, d_model)))), pos_embed (nnx.Param(jnp.zeros((num_patches + 1, d_model)))), blocks (nnx.List), ln_f, head.
- Forward: patch embed -> reshape -> concatenate [cls_token.value, seq] at axis 0 -> + pos_embed -> blocks -> LN -> read x[0] -> head.
- Cast hyperparameters to int. num_patches = (H // P) * (W // P).
- Return logits flattened.
Inputs:
- seed: int (passed as float).
- image: 3-D (H, W, C).
- patch_size, d_model, num_heads, d_ff, num_layers, num_classes: ints (passed as floats).
Output: 1-D (num_classes,).
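The bookkeeping around the model (casting float-valued arguments and deriving num_patches) might look like this sketch. The helper derive_shapes and its return format are hypothetical, and it covers only the setup, not the forward pass.

```python
import jax.numpy as jnp

def derive_shapes(image, patch_size, d_model, num_classes):
    """Hypothetical helper: cast hyperparameters and compute derived shapes."""
    P, D, K = int(patch_size), int(d_model), int(num_classes)  # floats -> ints
    H, W, C = image.shape
    num_patches = (H // P) * (W // P)
    return {
        "in_channels": C,
        "num_patches": num_patches,
        "pos_embed_shape": (num_patches + 1, D),   # + 1 for the CLS slot
        "output_shape": (K,),                      # logits flattened to 1-D
    }

shapes = derive_shapes(jnp.zeros((32, 32, 3)), 8.0, 64.0, 10.0)
```

With a 32 x 32 image and patch size 8 this gives 16 patches, so pos_embed has 17 rows.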