ViT Encoder Block
Implement a pre-LN Vision Transformer encoder block — the repeating unit that stacks to form the full ViT backbone (Dosovitskiy et al., 2020).
Pre-LN vs post-LN
The original 2017 Transformer (Vaswani et al.) applies LayerNorm after the residual addition (post-LN):
out = LayerNorm(x + SubLayer(x))
Modern architectures — GPT-2, ViT, and most LLMs — instead apply LayerNorm before the sub-layer (pre-LN):
out = x + SubLayer(LayerNorm(x))
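Pre-LN leaves the skip path unnormalized: the residual stream x is passed through unchanged, and only the sub-layer input is normalized. Gradients can therefore flow through the identity connection directly, which keeps their magnitudes more stable during training and reduces the need for careful warm-up schedules.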
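The difference between the two formulas can be sketched in a few lines. This is a minimal illustration, assuming NumPy arrays rather than framework tensors; `post_ln`, `pre_ln`, and `zero_sublayer` are hypothetical names introduced only for this example.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last dimension, with no learned scale or shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def post_ln(x, sublayer):
    # Original Transformer: normalize after the residual addition.
    return layer_norm(x + sublayer(x))

def pre_ln(x, sublayer):
    # ViT / GPT-2 style: normalize only the sub-layer input.
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 4, 8))

# Hypothetical stand-in sub-layer that contributes nothing.
zero_sublayer = lambda h: np.zeros_like(h)

# Pre-LN passes the residual stream through untouched,
# while post-LN still rescales it on every block.
pre_out = pre_ln(x, zero_sublayer)
post_out = post_ln(x, zero_sublayer)
```

With a sub-layer that outputs zeros, `pre_out` is identical to `x`, whereas `post_out` has been re-centered and rescaled. That is the sense in which the pre-LN skip path is an unmodified identity connection.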
Block structure (pre-LN)
Given input x of shape (N, T, d_model):
- Attention half:
  norm1 = LayerNorm(x)
  attn_out = MHA(norm1)  (multi-head self-attention, no causal mask)
  x = x + attn_out
- FFN half:
  norm2 = LayerNorm(x)
  mlp_out = GELU(norm2 @ w_mlp1) @ w_mlp2
  x = x + mlp_out
The residual is added outside the LayerNorm in both halves.
Details
- LayerNorm: (x - mean) / sqrt(var + eps) over the last dim, eps=1e-5. No learned γ/β.
- Multi-head attention: standard scaled dot-product. Split d_model into num_heads heads of size d_head = d_model // num_heads. Q = norm1 @ w_q, K = norm1 @ w_k, V = norm1 @ w_v. Per head, compute softmax(Q @ Kᵀ / sqrt(d_head)) @ V. No causal mask: ViT uses bidirectional attention. Output projection: concat @ w_o.
- GELU (tanh approximation): 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))).
- No nn.MultiheadAttention, no F.scaled_dot_product_attention, no nn.LayerNorm: implement everything from scratch.
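The element-wise pieces above are small enough to check numerically. The following is a rough sketch in NumPy, not a reference solution. It assumes the biased (population) variance in LayerNorm, which the problem statement does not spell out, and the helper names `gelu_tanh`, `split_heads`, and `merge_heads` are hypothetical.

```python
import math
import numpy as np

def layer_norm(x, eps=1e-5):
    # Statistics are taken over the last dimension only (one per token).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # assumption: biased variance
    return (x - mean) / np.sqrt(var + eps)

def gelu_tanh(x):
    # Tanh approximation of GELU, as given in the details list.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + np.tanh(inner))

def split_heads(t, num_heads):
    # (N, T, d_model) -> (N, num_heads, T, d_head)
    n, seq, d_model = t.shape
    d_head = d_model // num_heads
    return t.reshape(n, seq, num_heads, d_head).transpose(0, 2, 1, 3)

def merge_heads(t):
    # (N, num_heads, T, d_head) -> (N, T, d_model)
    n, h, seq, d_head = t.shape
    return t.transpose(0, 2, 1, 3).reshape(n, seq, h * d_head)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))

normed = layer_norm(x)                # each token: mean ~0, variance ~1
heads = split_heads(x, num_heads=2)   # shape (2, 2, 5, 4)
roundtrip = merge_heads(heads)        # back to (2, 5, 8)
gelu_at_one = gelu_tanh(np.array([1.0]))[0]
```

Splitting and merging heads are pure reshapes, so `roundtrip` recovers `x` exactly. The same `merge_heads` step is what produces the `concat` tensor that is multiplied by `w_o`.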
References
- Dosovitskiy et al., “An Image Is Worth 16x16 Words”, ICLR 2021 — pre-LN ViT.
- Vaswani et al., “Attention Is All You Need”, NeurIPS 2017 — original post-LN.
This block in context
This block composes directly with Patch Embedding with CLS Token (which precedes it) and a linear classification head (which follows). Stack several of these blocks in sequence and you have the full ViT encoder.
Inputs / Output
- x: (N, T, d_model), the sequence of patch embeddings (+ CLS token).
- w_q, w_k, w_v, w_o: (d_model, d_model), attention projections.
- w_mlp1: (d_model, d_ff), FFN up-projection.
- w_mlp2: (d_ff, d_model), FFN down-projection.
- num_heads: int; d_model must be divisible by num_heads.
- Output: (N, T, d_model).
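Putting the pieces together, one possible shape for the whole block is sketched below. It is an illustration under stated assumptions, not the site's reference solution: it uses NumPy instead of a tensor framework, assumes biased variance in LayerNorm, and the function name `vit_encoder_block` and the toy sizes are hypothetical.

```python
import math
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # assumption: biased variance
    return (x - mean) / np.sqrt(var + eps)

def gelu_tanh(x):
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + np.tanh(inner))

def softmax(z):
    # Subtract the row maximum for numerical stability.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def vit_encoder_block(x, w_q, w_k, w_v, w_o, w_mlp1, w_mlp2, num_heads):
    n, seq, d_model = x.shape
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_head = d_model // num_heads

    def split(t):
        # (N, T, d_model) -> (N, num_heads, T, d_head)
        return t.reshape(n, seq, num_heads, d_head).transpose(0, 2, 1, 3)

    # Attention half: normalize, attend over all tokens, add the residual.
    norm1 = layer_norm(x)
    q, k, v = split(norm1 @ w_q), split(norm1 @ w_k), split(norm1 @ w_v)
    scores = q @ k.transpose(0, 1, 3, 2) / math.sqrt(d_head)  # (N, H, T, T)
    weights = softmax(scores)  # no causal mask: every token sees every token
    heads = weights @ v        # (N, H, T, d_head)
    concat = heads.transpose(0, 2, 1, 3).reshape(n, seq, d_model)
    x = x + concat @ w_o

    # FFN half: normalize, expand, apply GELU, project back, add the residual.
    norm2 = layer_norm(x)
    x = x + gelu_tanh(norm2 @ w_mlp1) @ w_mlp2
    return x

# Toy sizes chosen only for illustration.
rng = np.random.default_rng(0)
N, T, d_model, d_ff, num_heads = 2, 5, 8, 32, 2
x = rng.normal(size=(N, T, d_model))
w_q, w_k, w_v, w_o = (0.1 * rng.normal(size=(d_model, d_model)) for _ in range(4))
w_mlp1 = 0.1 * rng.normal(size=(d_model, d_ff))
w_mlp2 = 0.1 * rng.normal(size=(d_ff, d_model))

out = vit_encoder_block(x, w_q, w_k, w_v, w_o, w_mlp1, w_mlp2, num_heads)
print(out.shape)
```

Two sanity checks follow from the design. Setting `w_o` and `w_mlp2` to zero should return the input unchanged, because both residual branches then contribute nothing. Permuting the tokens of `x` should permute the output in the same way, because the block has no causal mask and no positional term of its own.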