ViT Patch Embedding
Why this matters
The Vision Transformer (ViT, Dosovitskiy et al. 2021) treats an image as a sequence of fixed-size patches. To get tokens out of pixels, you:
1. Cut the (H, W, C) image into an (H/p) × (W/p) grid of non-overlapping patches, each of size p × p × C.
2. Flatten each patch to a p²·C vector.
3. Linearly project each patch to d_model.
4. Treat the resulting (num_patches, d_model) array as a sequence and feed it to a Transformer.
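To make steps 1–3 concrete, here is a minimal sketch in plain jax.numpy, assuming H and W are divisible by p. The helper name patchify is hypothetical, and the projection matrix is random noise standing in for learned weights.

```python
import jax
import jax.numpy as jnp

def patchify(image, p):
    """Hypothetical helper for steps 1-2: cut an (H, W, C) image into
    non-overlapping p x p patches and flatten each one.
    Returns (num_patches, p*p*C), patches in row-major order."""
    H, W, C = image.shape
    assert H % p == 0 and W % p == 0, "H and W must be divisible by p"
    x = image.reshape(H // p, p, W // p, p, C)  # split rows and columns into blocks of p
    x = x.transpose(0, 2, 1, 3, 4)              # (H/p, W/p, p, p, C): group each patch's pixels
    return x.reshape(-1, p * p * C)             # flatten each patch to a p²·C vector

H, W, C, p, d_model = 8, 8, 3, 4, 16
image = jax.random.normal(jax.random.PRNGKey(0), (H, W, C))
patches = patchify(image, p)                    # (4, 48)

# Step 3: one shared linear projection (random here; learned in a real model).
proj = jax.random.normal(jax.random.PRNGKey(1), (p * p * C, d_model))
tokens = patches @ proj                         # (num_patches, d_model) = (4, 16)
print(patches.shape, tokens.shape)
```

Patches come out in row-major order: patch k starts at row (k // (W/p))·p and column (k % (W/p))·p, and that is the sequence order the Transformer sees.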
Steps 1–3 fuse into a single op: a strided convolution with
kernel = stride = p and out_channels = d_model. Because the conv slides over
the image with stride p, its windows never overlap; each output spatial
location corresponds to exactly one patch, and the D filters (one per output
channel) together form the linear projection. One layer, one (p, p, C, D) weight
tensor plus a length-D bias: that’s the whole patch embedding.
The Flax API
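conv = nn.Conv(
    features=D,
    kernel_size=(p, p),
    strides=(p, p),
    padding="VALID",
)
params = conv.init(jax.random.PRNGKey(0), image)  # a standalone Linen module needs init/apply
feat = conv.apply(params, image)  # (H, W, C) → (H/p, W/p, D)
tokens = feat.reshape(-1, D)      # (num_patches, D), num_patches = (H/p)·(W/p)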
Key arguments:
- features=D: output channel count (= patch embedding dim).
- kernel_size=(p, p): spatial filter size.
- strides=(p, p): matches the kernel, so patches don't overlap.
- padding="VALID": no zero padding. With H, W divisible by p, this gives exactly H/p × W/p patches.
Flax’s nn.Conv operates on (H, W, C) (or (B, H, W, C)) — channels
last. The kernel parameter shape is (p, p, C, D).
Worked example
import jax
import jax.numpy as jnp
import flax.linen as nn

H, W, C = 4, 4, 3
p = 2
D = 4
image = jnp.ones((H, W, C))
conv = nn.Conv(features=D, kernel_size=(p, p), strides=(p, p), padding="VALID")
params = conv.init(jax.random.PRNGKey(0), image)
feat = conv.apply(params, image)  # (2, 2, 4): a 2×2 grid of patches, D=4
tokens = feat.reshape(4, 4)       # (num_patches, D) = (4, 4)
A 4×4 RGB image with patch size 2 yields (4/2)·(4/2) = 4 patches with 4 features each.
Real ViT-Base uses H = W = 224 and p = 16, giving (224/16)² = 14² = 196 patches of dimension 768.
Why a conv works
Mathematically, applying a (p, p, C, D) linear projection to every
non-overlapping (p, p, C) patch IS a strided conv with that kernel
and stride. PyTorch ViT implementations typically build this layer as
nn.Conv2d(in_channels=C, out_channels=D, kernel_size=p, stride=p), which
expects channels-first (B, C, H, W) input; Flax uses the same layer with channels-last layout.
You could implement patchification with reshape + matmul, but the conv formulation is one line and hands the work to the backend's optimised convolution kernels.
Common pitfalls
- padding='SAME': gives H/p × W/p outputs only when H and W are multiples of p; otherwise it zero-pads the edges and you get an extra row/column of partly padded patches. Use 'VALID' for the standard ViT.
- H or W not divisible by p: with 'VALID' padding, the leftover bottom rows / right columns of pixels are silently dropped. Pre-resize images to a multiple of p.
- stride != kernel_size: patches overlap (stride < kernel) or pixels are skipped (stride > kernel). That is no longer the standard ViT patch embedding; some later models (e.g. PVT v2) use overlapping patch embeddings on purpose.
- Forgetting the reshape: after the conv, the output is (H/p, W/p, D). You usually want a flat sequence (num_patches, D) for downstream attention.
Problem
Implement patch_embed_forward(seed, image, patch_size, d_model):
- Cast patch_size and d_model to int.
- Build nn.Conv(features=D, kernel_size=(p, p), strides=(p, p), padding="VALID").
- Init with jax.random.PRNGKey(seed), then apply on image.
- Reshape the (H/p, W/p, D) output to (num_patches, D).
- Return the result flattened.
Inputs:
- seed: int.
- image: 3-D float array (H, W, C). H and W are divisible by patch_size.
- patch_size: int p.
- d_model: int D.
Output: 1-D array of length (H/p · W/p) · D.