ViT with Mixup Augmentation
Extend a ViT training step with Mixup data augmentation.
Mixup (Zhang et al., 2018) acts as a regularizer by linearly interpolating between pairs of training examples and their labels. Instead of training on a single example, you train on a convex combination of two:
x_mixed = λ * x_a + (1 - λ) * x_b
y_mixed = λ * y_a + (1 - λ) * y_b # soft one-hot targets
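Here is a tiny worked example of the two formulas. The values are chosen only for illustration: λ = 0.75, two 2-dimensional inputs, and 3 classes.

```python
import torch
import torch.nn.functional as F

lam = 0.75                      # illustrative mix coefficient
x_a = torch.tensor([1.0, 0.0])  # first example
x_b = torch.tensor([0.0, 1.0])  # second example
y_a = torch.tensor([0])         # class index of x_a
y_b = torch.tensor([2])         # class index of x_b
num_classes = 3

x_mixed = lam * x_a + (1 - lam) * x_b
# Labels become soft targets: a convex combination of two one-hot vectors.
y_mixed = lam * F.one_hot(y_a, num_classes).float() + (1 - lam) * F.one_hot(y_b, num_classes).float()

print(x_mixed)  # tensor([0.7500, 0.2500])
print(y_mixed)  # tensor([[0.7500, 0.0000, 0.2500]])
```

The mixed target still sums to 1, so it remains a valid distribution for cross-entropy.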
The mix coefficient λ is sampled from Beta(α, α). This task uses α = 0.4, a common choice. Beta(α, α) is symmetric about 0.5, and for α < 1 its density is U-shaped, with most of its mass near 0 and 1. As a result, most mixed examples stay close to one of the two clean examples.
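The sketch below checks the U shape numerically. It draws many λ values with torch.distributions.Beta and measures how often λ lands within 0.25 of 0 or 1. The threshold, sample count, and the helper name frac_near_clean are arbitrary choices for illustration.

```python
import torch

def frac_near_clean(alpha, n=100_000, threshold=0.25, seed=0):
    """Fraction of λ ~ Beta(alpha, alpha) draws within `threshold` of 0 or 1 (hypothetical helper)."""
    torch.manual_seed(seed)
    a = torch.tensor(alpha)
    lam = torch.distributions.Beta(a, a).sample((n,))
    dist = torch.minimum(lam, 1 - lam)  # distance from λ to the nearer endpoint
    return (dist < threshold).float().mean().item()

for alpha in (0.4, 1.0):  # α = 1 is the uniform distribution, for comparison
    print(alpha, round(frac_near_clean(alpha), 3))
```

With α = 0.4, roughly 70% of draws fall within 0.25 of an endpoint. For the uniform case α = 1, the figure is 50%.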
Pipeline
- Sample λ ~ Beta(α, α) using seed. λ is one scalar, shared by the entire batch.
- Permute the batch: idx = torch.randperm(N, generator=torch.Generator().manual_seed(seed + 1)).
- Mix inputs: x_mixed = λ * x + (1 - λ) * x[idx].
- Mix labels (one-hot): y_mixed = λ * one_hot(y) + (1 - λ) * one_hot(y[idx]).
- Forward through the ViT classifier on x_mixed → logits.
- Soft-label CE loss: -mean(sum(y_mixed * log_softmax(logits), dim=-1)), i.e. sum over classes, then mean over the batch.
- Manual backward (chain rule through every layer, same as the previous task).
- SGD update, then return all updated weights concatenated into one flat tensor.
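Steps 1 to 4 above can be sketched as a single function. This is a minimal sketch under stated assumptions. The name mixup_batch is hypothetical. λ is drawn as G1 / (G1 + G2) with G1, G2 ~ Gamma(α, 1) from torch._standard_gamma in float64. Because of these choices, the exact λ for a given seed may not match the reference solution.

```python
import torch
import torch.nn.functional as F

def mixup_batch(x, y, num_classes, alpha=0.4, seed=0):
    """Hypothetical helper: returns (x_mixed, y_mixed, lam, idx)."""
    # Step 1: one scalar λ ~ Beta(alpha, alpha) for the whole batch, built from two
    # Gamma(alpha, 1) draws taken from a generator seeded with `seed` (assumed scheme).
    g_lam = torch.Generator().manual_seed(seed)
    gammas = torch._standard_gamma(torch.full((2,), alpha, dtype=torch.float64), generator=g_lam)
    lam = (gammas[0] / gammas.sum()).item()
    # Step 2: batch permutation from a separate generator seeded with seed + 1.
    N = x.shape[0]
    idx = torch.randperm(N, generator=torch.Generator().manual_seed(seed + 1))
    # Step 3: mix inputs with their permuted partners.
    x_mixed = lam * x + (1 - lam) * x[idx]
    # Step 4: mix one-hot labels into soft targets.
    y_onehot = F.one_hot(y, num_classes).to(x.dtype)
    y_mixed = lam * y_onehot + (1 - lam) * y_onehot[idx]
    return x_mixed, y_mixed, lam, idx

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)     # illustrative image batch
y = torch.tensor([0, 1, 2, 1])  # illustrative labels, 3 classes
x_mixed, y_mixed, lam, idx = mixup_batch(x, y, num_classes=3, alpha=0.4, seed=42)
```

The two generators are separate, so λ and the permutation are each reproducible from seed alone.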
Forward (identical to the previous ViT task)
patches = extract_patches(x_mixed)                 # (N, num_patches, C·P²)
patches = patches @ w_proj                         # (N, num_patches, d)
cls = cls_token.expand(N, -1, -1)                  # (N, 1, d)
seq = cat([cls, patches], dim=1) + pos_embed       # (N, T, d), T = num_patches + 1
for each block:
    norm1 = layer_norm(seq)
    Q, K, V = norm1 @ w_q, norm1 @ w_k, norm1 @ w_v  # each split into heads of size d_head
    attn = softmax(Q @ K.transpose(-2, -1) / sqrt(d_head), dim=-1)
    seq = seq + (attn @ V) @ w_o                   # residual 1 (heads merged back to d)
    norm2 = layer_norm(seq)
    seq = seq + gelu(norm2 @ w_mlp1) @ w_mlp2      # residual 2
cls_repr = seq[:, 0, :]                            # (N, d)
logits = cls_repr @ w_head                         # (N, num_classes)
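Below is a runnable version of this forward pass, as a sketch under assumptions not fixed by the text:

- extract_patches uses F.unfold, which flattens each patch in (C, P, P) order.
- layer_norm has no learnable scale or shift.
- GELU is the exact (erf) form.
- Block weights live in hypothetical per-block dicts keyed w_q, w_k, w_v, w_o, w_mlp1, w_mlp2.

Match these choices to the previous task before reusing it.

```python
import math
import torch
import torch.nn.functional as F

def extract_patches(x, P):
    # (N, C, H, W) -> (N, num_patches, C*P*P): non-overlapping P x P patches.
    return F.unfold(x, kernel_size=P, stride=P).transpose(1, 2)

def layer_norm(x, eps=1e-5):
    # Normalize over the last dim; no learnable scale/shift (assumption).
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mu) / torch.sqrt(var + eps)

def vit_forward(x, params, P, num_heads):
    N = x.shape[0]
    seq = extract_patches(x, P) @ params["w_proj"]            # (N, num_patches, d)
    cls = params["cls_token"].expand(N, -1, -1)               # (N, 1, d)
    seq = torch.cat([cls, seq], dim=1) + params["pos_embed"]  # (N, T, d)
    _, T, d = seq.shape
    d_head = d // num_heads

    def split_heads(t):  # (N, T, d) -> (N, num_heads, T, d_head)
        return t.reshape(N, T, num_heads, d_head).transpose(1, 2)

    for blk in params["blocks"]:
        norm1 = layer_norm(seq)
        Q = split_heads(norm1 @ blk["w_q"])
        K = split_heads(norm1 @ blk["w_k"])
        V = split_heads(norm1 @ blk["w_v"])
        attn = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_head), dim=-1)
        merged = (attn @ V).transpose(1, 2).reshape(N, T, d)  # heads merged back to d
        seq = seq + merged @ blk["w_o"]                       # residual 1
        norm2 = layer_norm(seq)
        seq = seq + F.gelu(norm2 @ blk["w_mlp1"]) @ blk["w_mlp2"]  # residual 2
    return seq[:, 0, :] @ params["w_head"]                    # (N, num_classes)

# Usage with small random weights (all sizes are illustrative).
torch.manual_seed(0)
N, C, H, W, P, d, num_heads, d_mlp, num_classes, depth = 2, 3, 8, 8, 4, 8, 2, 16, 5, 2
T = (H // P) * (W // P) + 1
params = {
    "w_proj": 0.1 * torch.randn(C * P * P, d),
    "cls_token": 0.1 * torch.randn(1, 1, d),
    "pos_embed": 0.1 * torch.randn(1, T, d),
    "blocks": [
        {"w_q": 0.1 * torch.randn(d, d), "w_k": 0.1 * torch.randn(d, d),
         "w_v": 0.1 * torch.randn(d, d), "w_o": 0.1 * torch.randn(d, d),
         "w_mlp1": 0.1 * torch.randn(d, d_mlp), "w_mlp2": 0.1 * torch.randn(d_mlp, d)}
        for _ in range(depth)
    ],
    "w_head": 0.1 * torch.randn(d, num_classes),
}
logits = vit_forward(torch.randn(N, C, H, W), params, P, num_heads)
print(logits.shape)  # torch.Size([2, 5])
```

For the manual backward, every intermediate computed here (norm1, Q, K, V, attn, norm2, the MLP pre-activation) would also need to be cached.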
Backward
The only change from hard-label CE is the gradient at the logits:
dlogits = (softmax(logits) - y_mixed) / N
From dlogits, the backward chain through all ViT weights proceeds exactly as before.
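The sketch below computes the soft-label loss and dlogits, then takes the first step of the unchanged chain through the classification head. The helper name and the toy sizes are hypothetical. Autograd appears only as a correctness check, since the task itself forbids it.

```python
import torch
import torch.nn.functional as F

def soft_ce_loss_and_dlogits(logits, y_mixed):
    """Soft-label cross-entropy (mean over the batch) and its gradient w.r.t. logits."""
    N = logits.shape[0]
    log_probs = torch.log_softmax(logits, dim=-1)
    loss = -(y_mixed * log_probs).sum(dim=-1).mean()  # sum over classes, mean over batch
    dlogits = (log_probs.exp() - y_mixed) / N         # (softmax(logits) - y_mixed) / N
    return loss, dlogits

# Illustrative batch: 4 examples, 6-dim CLS representation, 3 classes.
torch.manual_seed(0)
N, d, num_classes, lam = 4, 6, 3, 0.3
cls_repr = torch.randn(N, d)
w_head = torch.randn(d, num_classes)
y = torch.tensor([0, 1, 2, 1])
idx = torch.tensor([2, 0, 3, 1])  # a fixed permutation of the batch
y_onehot = F.one_hot(y, num_classes).float()
y_mixed = lam * y_onehot + (1 - lam) * y_onehot[idx]

loss, dlogits = soft_ce_loss_and_dlogits(cls_repr @ w_head, y_mixed)
# First step of the unchanged backward chain: logits = cls_repr @ w_head.
dw_head = cls_repr.T @ dlogits  # (d, num_classes)
dcls_repr = dlogits @ w_head.T  # (N, d); written into dseq[:, 0, :], zeros elsewhere

# Sanity check against autograd (verification only; not part of the manual solution).
w_chk = w_head.clone().requires_grad_(True)
loss_chk = -(y_mixed * torch.log_softmax(cls_repr @ w_chk, dim=-1)).sum(dim=-1).mean()
(dw_auto,) = torch.autograd.grad(loss_chk, w_chk)
print(torch.allclose(dw_auto, dw_head, atol=1e-6))  # True
```

Each row of dlogits sums to zero, because both softmax(logits) and y_mixed sum to 1 over the classes.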
Output
Returns a single flat tensor of all updated weights:
[w_proj_flat, cls_token_flat, pos_embed_flat, blocks_weights_flat, w_head_flat]
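Here is a minimal sketch of the flattening. It assumes blocks are kept in forward order and that each block's weights are flattened in the order w_q, w_k, w_v, w_o, w_mlp1, w_mlp2. That per-block order, the dict layout, and the helper name are hypothetical; match whatever layout the previous task used.

```python
import torch

BLOCK_KEYS = ("w_q", "w_k", "w_v", "w_o", "w_mlp1", "w_mlp2")  # assumed per-block order

def flatten_weights(w_proj, cls_token, pos_embed, blocks, w_head):
    parts = [w_proj, cls_token, pos_embed]
    for blk in blocks:  # blocks in forward order
        parts.extend(blk[k] for k in BLOCK_KEYS)
    parts.append(w_head)
    return torch.cat([p.reshape(-1) for p in parts])  # row-major flatten of each tensor

# Sizes for illustration: patch_dim=12, d=4, T=5, d_mlp=8, 3 classes, 2 blocks.
patch_dim, d, T, d_mlp, num_classes = 12, 4, 5, 8, 3
blocks = [
    {"w_q": torch.randn(d, d), "w_k": torch.randn(d, d), "w_v": torch.randn(d, d),
     "w_o": torch.randn(d, d), "w_mlp1": torch.randn(d, d_mlp), "w_mlp2": torch.randn(d_mlp, d)}
    for _ in range(2)
]
w_proj = torch.randn(patch_dim, d)
flat = flatten_weights(w_proj, torch.randn(1, 1, d), torch.randn(1, T, d), blocks,
                       torch.randn(d, num_classes))
print(flat.numel())  # 340
```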
Implementation constraints
- No loss.backward(): implement the backward pass by hand.
- For Beta sampling, use torch._standard_gamma with a seeded torch.Generator, or torch.distributions.Beta(alpha, alpha).sample() after seeding the global RNG.
- Cache all intermediate activations during the forward pass; the backward pass needs them.
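As an example of caching for the manual backward, here is a LayerNorm forward that returns a cache, plus a matching backward. It assumes LayerNorm has no learnable scale or shift (the output list contains no LayerNorm parameters) and eps = 1e-5; both are assumptions, and the function names are hypothetical. The backward uses the standard identity dx = (dy - mean(dy) - x̂ · mean(dy · x̂)) / σ, where σ = sqrt(var + eps) and the means run over the normalized dimension.

```python
import torch

def layer_norm_fwd(x, eps=1e-5):
    # LayerNorm over the last dim, no affine parameters (assumption).
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    inv_std = torch.rsqrt(var + eps)  # 1 / σ
    x_hat = (x - mu) * inv_std
    return x_hat, (x_hat, inv_std)    # output plus the cache the backward needs

def layer_norm_bwd(dy, cache):
    x_hat, inv_std = cache
    mean_dy = dy.mean(dim=-1, keepdim=True)
    mean_dy_xhat = (dy * x_hat).mean(dim=-1, keepdim=True)
    return inv_std * (dy - mean_dy - x_hat * mean_dy_xhat)

# Sanity check against autograd (verification only; float64 for a tight comparison).
torch.manual_seed(0)
x = torch.randn(2, 5, 8, dtype=torch.float64, requires_grad=True)
dy = torch.randn(2, 5, 8, dtype=torch.float64)
y, cache = layer_norm_fwd(x)
(dx_auto,) = torch.autograd.grad(y, x, grad_outputs=dy)
dx_manual = layer_norm_bwd(dy, cache)
print(torch.allclose(dx_manual, dx_auto))  # True
```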
References
- Zhang et al., “mixup: Beyond Empirical Risk Minimization”, ICLR 2018.
- Dosovitskiy et al., “An Image Is Worth 16x16 Words”, ICLR 2021.