ViT Patch Embedding

Why this matters

The Vision Transformer (ViT, Dosovitskiy et al. 2021) treats an image as a sequence of fixed-size patches. To get tokens out of pixels, you:

  1. Cut the (H, W, C) image into an (H/p) × (W/p) grid of non-overlapping patches, each of size p × p × C.
  2. Flatten each patch to a p²·C vector.
  3. Linear-project each patch to d_model.
  4. Treat the resulting (num_patches, d_model) array as a sequence and feed it to a Transformer.
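Before taking the conv shortcut, it can help to see steps 1–4 spelled out. The sketch below does them with plain jax.numpy reshapes and a matmul. The helper name patchify and the random projection weights are illustrative assumptions; a trained model learns w and b.

```python
import jax
import jax.numpy as jnp

def patchify(image, p):
    """Steps 1-2 (hypothetical helper): cut an (H, W, C) image into
    non-overlapping p x p patches and flatten each one.
    Returns (num_patches, p*p*C), patches in row-major grid order."""
    H, W, C = image.shape
    assert H % p == 0 and W % p == 0, "H and W must be divisible by p"
    x = image.reshape(H // p, p, W // p, p, C)   # split both spatial axes
    x = x.transpose(0, 2, 1, 3, 4)               # (H/p, W/p, p, p, C)
    return x.reshape((H // p) * (W // p), p * p * C)

H, W, C, p, D = 4, 4, 3, 2, 8
image = jax.random.normal(jax.random.PRNGKey(0), (H, W, C))
patches = patchify(image, p)                     # (4, 12)

# Step 3: one shared linear projection (random weights stand in for learned ones).
w = jax.random.normal(jax.random.PRNGKey(1), (p * p * C, D))
b = jnp.zeros((D,))
tokens = patches @ w + b                         # Step 4: (num_patches, D) = (4, 8)
```

Every patch goes through the same w, which is exactly what lets a single conv filter bank replace steps 1–3.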

Steps 1–3 fuse into a single op: a strided convolution with kernel = stride = p and out_channels = d_model. Because the stride equals the kernel size, the conv windows tile the image without overlapping, so each output spatial location corresponds to exactly one patch, and the conv's D filters (one per output channel) together form the linear projection. One layer with one (p, p, C, D) weight tensor (plus a length-D bias) is the whole patch embedding.

The Flax API

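import flax.linen as nn

conv = nn.Conv(
    features=D,
    kernel_size=(p, p),
    strides=(p, p),
    padding="VALID",
)
# Call inside a parent module's @nn.compact __call__ (standalone: use conv.init / conv.apply).
feat = conv(image)                      # (H, W, C) → (H/p, W/p, D)
num_patches = (H // p) * (W // p)
tokens = feat.reshape(num_patches, D)   # (num_patches, D)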

Key arguments:

  • features=D: output channel count (= patch embedding dim).
  • kernel_size=(p, p): spatial filter size.
  • strides=(p, p): matches the kernel — no overlap.
  • padding="VALID": no zero padding. With H, W divisible by p, this gives exactly H/p × W/p patches.
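For reference, the output-size rule behind the "VALID" bullet (the standard formula for an unpadded strided convolution with kernel p and stride p):

```latex
H_{\text{out}} = \left\lfloor \frac{H - p}{p} \right\rfloor + 1,
\qquad
W_{\text{out}} = \left\lfloor \frac{W - p}{p} \right\rfloor + 1,
\qquad
N = H_{\text{out}} \, W_{\text{out}}
```

When p divides H, (H − p)/p = H/p − 1 exactly, so H_out = H/p; for example H = 224, p = 16 gives 14.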

Flax’s nn.Conv operates on (H, W, C) (or (B, H, W, C)) — channels last. The kernel parameter shape is (p, p, C, D).

Worked example

import jax
import jax.numpy as jnp
import flax.linen as nn

H, W, C = 4, 4, 3
p = 2
D = 4

image = jnp.ones((H, W, C))
conv = nn.Conv(features=D, kernel_size=(p, p), strides=(p, p), padding="VALID")
params = conv.init(jax.random.PRNGKey(0), image)
feat = conv.apply(params, image)        # (2, 2, 4): 4 patches, D=4
tokens = feat.reshape(4, 4)             # (num_patches, D)

A 4×4 RGB image with patch size 2 yields 4 patches with 4 features each. ViT-Base at its standard resolution uses H = W = 224 and p = 16, giving (224/16)² = 14² = 196 patches, each embedded into 768 dimensions.

Why a conv works

Mathematically, applying one shared (p, p, C, D) linear projection to every non-overlapping (p, p, C) patch IS a strided conv with that kernel and stride p. PyTorch ViT implementations (torchvision's VisionTransformer, for example) build the same layer as nn.Conv2d(in_channels=C, out_channels=D, kernel_size=p, stride=p). The Flax version is identical apart from its channels-last layout.
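Written out, with x the (H, W, C) image, K the (p, p, C, D) kernel, b the bias, P_ij the patch at grid position (i, j) flattened in (u, v, c) order, and W the kernel reshaped to (p²C, D) in the same order, output token (i, j) is (deep-learning "convolution" is a cross-correlation, so there is no kernel flip):

```latex
y_{i,j,d} = b_d + \sum_{u=0}^{p-1} \sum_{v=0}^{p-1} \sum_{c=0}^{C-1}
            x_{ip+u,\; jp+v,\; c} \, K_{u,v,c,d}
\qquad\Longrightarrow\qquad
y_{i,j,:} = P_{ij} \, W + b,
\quad P_{ij} \in \mathbb{R}^{1 \times p^2 C},\;
W \in \mathbb{R}^{p^2 C \times D}
```

Because stride = p, the windows for different (i, j) never share a pixel, so the conv computes exactly one matmul per patch.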

You could also implement patchification with reshape + matmul, but the conv formulation is a single layer and lets the compiler dispatch to its optimized convolution kernels.

Common pitfalls

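  • padding='SAME': with stride p this produces ceil(H/p) × ceil(W/p) outputs. That equals H/p × W/p only when H and W are multiples of p; otherwise the borders are zero-padded and you get an extra row and/or column of partially padded patches. Use 'VALID' for the standard ViT.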
  • H or W not divisible by p: with 'VALID' padding, the leftover bottom rows and right columns (H mod p and W mod p pixels) are silently dropped. Resize or crop images to a multiple of p beforehand.
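  • stride != kernel_size: stride < kernel_size makes patches overlap, and stride > kernel_size skips pixels between patches. Neither is the standard ViT patch embedding; overlapping patch embeddings belong to other architectures (e.g. CvT, PVTv2). Swin, by contrast, also patchifies without overlap; its windows restrict attention, they are not a patch embedding.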
  • Forgetting the reshape: after conv, the output is (H/p, W/p, D). You usually want a flat sequence (num_patches, D) for downstream attention.

Problem

Implement patch_embed_forward(seed, image, patch_size, d_model):

  1. Cast patch_size, d_model to int.
  2. Build nn.Conv(features=D, kernel_size=(p, p), strides=(p, p), padding="VALID").
  3. Init with jax.random.PRNGKey(seed), apply on image.
  4. Reshape the (H/p, W/p, D) output to (num_patches, D).
  5. Flatten the result to 1-D and return it.

Inputs:

  • seed: int.
  • image: 3-D float (H, W, C). H and W divisible by patch_size.
  • patch_size: int p.
  • d_model: int D.

Output: 1-D array of length (H/p · W/p) · D.

