medium end_to_end

ViT Encoder Block

Implement a pre-LN Vision Transformer encoder block — the repeating unit that stacks to form the full ViT backbone (Dosovitskiy et al., 2020).

Pre-LN vs post-LN

The original 2017 Transformer (Vaswani et al.) applies LayerNorm after the residual addition (post-LN):

out = LayerNorm(x + SubLayer(x))

Modern architectures — GPT-2, ViT, and most LLMs — instead apply LayerNorm before the sub-layer (pre-LN):

out = x + SubLayer(LayerNorm(x))

In pre-LN, the skip path carries the residual stream without any normalization, so gradients can flow straight through the identity connections. This tends to keep gradient magnitudes more stable during training and reduces the dependence on carefully tuned warm-up schedules.
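To make the difference concrete, here is a minimal sketch of the two residual patterns in PyTorch. The helper names (layer_norm, post_ln, pre_ln) and the all-zeros stand-in sub-layer are illustrative assumptions, not part of the problem's required interface; the variance is assumed to be the biased (population) variance.

```python
import torch

def layer_norm(x, eps=1e-5):
    # Normalize over the last dimension; no learned scale or shift.
    # Assumption: biased (population) variance, as in standard LayerNorm.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

def post_ln(x, sublayer):
    # Post-LN: the residual sum itself is normalized.
    return layer_norm(x + sublayer(x))

def pre_ln(x, sublayer):
    # Pre-LN: only the sub-layer input is normalized; x passes through untouched.
    return x + sublayer(layer_norm(x))

# Hypothetical stand-in sub-layer that outputs zeros, to isolate the skip path.
zero_sublayer = torch.zeros_like

torch.manual_seed(0)
x = torch.randn(2, 4, 8) * 5.0 + 3.0
pre_out = pre_ln(x, zero_sublayer)    # identical to x
post_out = post_ln(x, zero_sublayer)  # x has been re-normalized
```

With the sub-layer silenced, the pre-LN output is exactly the input, while the post-LN output has been rescaled to zero mean and unit variance. That is the sense in which the pre-LN skip path is left unnormalized.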

Block structure (pre-LN)

Given input x of shape (N, T, d_model):

  1. Attention half:
       norm1 = LayerNorm(x)
       attn_out = MHA(norm1)    (multi-head self-attention, no causal mask)
       x = x + attn_out

  2. FFN half:
       norm2 = LayerNorm(x)
       mlp_out = GELU(norm2 @ w_mlp1) @ w_mlp2
       x = x + mlp_out

The residual is added outside the LayerNorm in both halves.

Details

  • LayerNorm: (x - mean) / sqrt(var + eps) over the last dim, eps=1e-5. No learned γ/β.
  • Multi-head attention: standard scaled dot-product attention. Project Q = norm1 @ w_q, K = norm1 @ w_k, V = norm1 @ w_v, then split d_model into num_heads heads of size d_head = d_model // num_heads. Per head, compute softmax(Q @ Kᵀ / sqrt(d_head)) @ V, with the softmax taken over the key dimension. No causal mask: ViT uses bidirectional attention. Output projection: concat @ w_o.
  • GELU: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))).
  • No nn.MultiheadAttention, no F.scaled_dot_product_attention, no nn.LayerNorm — implement everything from scratch.
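The three building blocks above could be sketched from scratch roughly as follows. This is one possible reading of the spec, not a reference solution: the function names are hypothetical, and the use of biased (population) variance in the LayerNorm is an assumption the problem text does not state explicitly.

```python
import math
import torch

def layer_norm(x, eps=1e-5):
    # (x - mean) / sqrt(var + eps) over the last dim, no learned gamma/beta.
    # Assumption: biased (population) variance.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

def gelu(x):
    # Tanh approximation of GELU, as given in the problem statement.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + torch.tanh(inner))

def multi_head_attention(h, w_q, w_k, w_v, w_o, num_heads):
    # h: (N, T, d_model), already normalized by the caller.
    n, t, d_model = h.shape
    d_head = d_model // num_heads

    def split_heads(z):
        # (N, T, d_model) -> (N, num_heads, T, d_head)
        return z.reshape(n, t, num_heads, d_head).transpose(1, 2)

    q = split_heads(h @ w_q)
    k = split_heads(h @ w_k)
    v = split_heads(h @ w_v)

    # Scaled dot-product scores, (N, num_heads, T, T); no mask is applied.
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)
    weights = torch.softmax(scores, dim=-1)  # softmax over keys
    context = weights @ v                    # (N, num_heads, T, d_head)

    # Merge heads back to (N, T, d_model), then apply the output projection.
    concat = context.transpose(1, 2).reshape(n, t, d_model)
    return concat @ w_o
```

Each helper can be checked against PyTorch's built-ins in tests (for example torch.nn.functional.layer_norm, or F.gelu with approximate="tanh"), even though the solution itself may not call them.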

References

  • Dosovitskiy et al., “An Image Is Worth 16x16 Words”, ICLR 2021 — pre-LN ViT.
  • Vaswani et al., “Attention Is All You Need”, NeurIPS 2017 — original post-LN.

This block in context

This block composes directly with Patch Embedding with CLS Token (which precedes it) and a linear classification head (which follows). Stack L of these blocks (L is the depth, distinct from the batch size N used in the shapes) and you have the full ViT encoder.

Inputs / Output

  • x: (N, T, d_model) — sequence of patch embeddings (+ CLS token).
  • w_q, w_k, w_v, w_o: (d_model, d_model) — attention projections.
  • w_mlp1: (d_model, d_ff) — FFN up-projection.
  • w_mlp2: (d_ff, d_model) — FFN down-projection.
  • num_heads: int; d_model must be divisible by num_heads.
  • Output: (N, T, d_model).
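Putting the signature together, a hedged end-to-end sketch of the block might look like the following. The function name vit_encoder_block, the argument order, and the toy sizes in the usage lines are assumptions for illustration; the actual grader may expect a different interface. The LayerNorm again assumes biased (population) variance.

```python
import math
import torch

def layer_norm(x, eps=1e-5):
    # Normalize over the last dim, no learned parameters (biased variance assumed).
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

def gelu(x):
    # Tanh approximation of GELU.
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def vit_encoder_block(x, w_q, w_k, w_v, w_o, w_mlp1, w_mlp2, num_heads):
    n, t, d_model = x.shape
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_head = d_model // num_heads

    def split_heads(z):
        # (N, T, d_model) -> (N, num_heads, T, d_head)
        return z.reshape(n, t, num_heads, d_head).transpose(1, 2)

    # 1. Attention half: normalize, attend (no causal mask), add residual.
    norm1 = layer_norm(x)
    q, k, v = split_heads(norm1 @ w_q), split_heads(norm1 @ w_k), split_heads(norm1 @ w_v)
    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_head), dim=-1)
    concat = (weights @ v).transpose(1, 2).reshape(n, t, d_model)
    x = x + concat @ w_o

    # 2. FFN half: normalize, up-project, GELU, down-project, add residual.
    norm2 = layer_norm(x)
    x = x + gelu(norm2 @ w_mlp1) @ w_mlp2
    return x

# Usage with small illustrative sizes (not taken from the problem).
torch.manual_seed(0)
n, t, d_model, d_ff, num_heads = 2, 5, 8, 32, 2
x = torch.randn(n, t, d_model)
w_q, w_k, w_v, w_o = (torch.randn(d_model, d_model) * 0.1 for _ in range(4))
w_mlp1 = torch.randn(d_model, d_ff) * 0.1
w_mlp2 = torch.randn(d_ff, d_model) * 0.1
out = vit_encoder_block(x, w_q, w_k, w_v, w_o, w_mlp1, w_mlp2, num_heads)
print(out.shape)
```

Two sanity checks follow from the structure. Zeroing w_o and w_mlp2 should return x unchanged, since only the skip path remains. Permuting the tokens of x should permute the output the same way, because nothing in the block depends on token position.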
