ViT with Mixup Augmentation

Extend a ViT training step with Mixup data augmentation.

Mixup (Zhang et al., 2018) regularizes training by linearly interpolating between pairs of training examples and their labels. Instead of training on a single example, you train on a convex combination of two:

x_mixed = λ * x_a + (1 - λ) * x_b
y_mixed = λ * y_a + (1 - λ) * y_b        # soft one-hot targets
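A worked example of the label mix for a single pair, with λ = 0.7, three classes, y_a = 0 and y_b = 2 (values chosen only for illustration):

```python
import torch
import torch.nn.functional as F

lam = 0.7
num_classes = 3
y_a = torch.tensor([0])  # class index of example a
y_b = torch.tensor([2])  # class index of example b

# Convex combination of the two one-hot vectors gives a soft target
y_mixed = lam * F.one_hot(y_a, num_classes).float() + (1 - lam) * F.one_hot(y_b, num_classes).float()
print(y_mixed)  # tensor([[0.7000, 0.0000, 0.3000]])
```

The soft target still sums to 1, so it remains a valid probability distribution over classes.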

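The mix coefficient λ is sampled from Beta(α, α), with α = 0.4 used here; Zhang et al. report that α ∈ [0.1, 0.4] improves over ERM on ImageNet, while large α leads to underfitting. Beta(α, α) is symmetric about 0.5, and for α < 1 its density is U-shaped, with its mass concentrated toward 0 and 1, so many mixed examples stay close to one of the two clean examples.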
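An empirical check of the shape claim for α = 0.4, sketched with torch.distributions.Beta. The 0.2 cutoff for "close to a clean example" is an arbitrary illustrative threshold, not part of the task:

```python
import torch

torch.manual_seed(0)  # seed the global RNG so the check is reproducible
alpha = 0.4
beta = torch.distributions.Beta(torch.tensor(alpha), torch.tensor(alpha))
lam = beta.sample((100_000,))  # 100k independent draws of the mix coefficient

# Fraction of draws within 0.2 of either endpoint (dominated by one clean example)
near_endpoint = (torch.minimum(lam, 1 - lam) < 0.2).float().mean().item()
mean_lam = lam.mean().item()
print(f"mean λ = {mean_lam:.3f}, P(min(λ, 1-λ) < 0.2) ≈ {near_endpoint:.2f}")
```

For α = 0.4 this fraction comes out at roughly 0.64 to 0.65, compared with 0.4 for a uniform λ (α = 1), while the mean stays at 0.5 by symmetry.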

Pipeline

  1. Sample λ ~ Beta(α, α) using seed (one scalar, same for the entire batch).
  2. Permute the batch: idx = torch.randperm(N, generator=torch.Generator().manual_seed(seed + 1)).
  3. Mix inputs: x_mixed = λ * x + (1 - λ) * x[idx].
  4. Mix labels (one-hot): y_mixed = λ * one_hot(y) + (1 - λ) * one_hot(y[idx]).
  5. Forward x_mixed through the ViT classifier to get logits.
  6. Soft-label CE loss, summed over classes and averaged over the batch: -(y_mixed * log_softmax(logits, dim=-1)).sum(dim=-1).mean().
  7. Manual backward (chain rule through every layer, same as the previous task).
  8. SGD update and return all updated weights concatenated flat.
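Steps 1 to 4 and step 6 can be sketched in PyTorch as below. This is a minimal sketch under assumptions: λ is drawn from the global RNG after torch.manual_seed(seed) (the second option listed under the implementation constraints), the permutation uses the seed + 1 generator from step 2, and mixup_batch and soft_ce are hypothetical helper names. The random logits stand in for the ViT forward pass.

```python
import torch
import torch.nn.functional as F


def mixup_batch(x, y, num_classes, alpha=0.4, seed=0):
    """Steps 1-4: returns (x_mixed, y_mixed, lam). Note: reseeds the global RNG."""
    # Step 1: one scalar λ ~ Beta(alpha, alpha) for the whole batch
    torch.manual_seed(seed)
    lam = torch.distributions.Beta(torch.tensor(alpha), torch.tensor(alpha)).sample().item()

    # Step 2: batch permutation from a separate generator seeded with seed + 1
    n = x.shape[0]
    idx = torch.randperm(n, generator=torch.Generator().manual_seed(seed + 1))

    # Step 3: mix inputs with their permuted partners
    x_mixed = lam * x + (1 - lam) * x[idx]

    # Step 4: mix one-hot labels into soft targets
    y_onehot = F.one_hot(y, num_classes).to(x.dtype)
    y_mixed = lam * y_onehot + (1 - lam) * y_onehot[idx]
    return x_mixed, y_mixed, lam


def soft_ce(logits, y_mixed):
    """Step 6: soft-label cross-entropy, sum over classes, mean over the batch."""
    return -(y_mixed * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()


# Usage on a random batch
x = torch.randn(8, 3, 32, 32)
y = torch.randint(0, 10, (8,))
x_mixed, y_mixed, lam = mixup_batch(x, y, num_classes=10, seed=42)
logits = torch.randn(8, 10)  # stand-in for the ViT forward pass (step 5)
loss = soft_ce(logits, y_mixed)
```

Using one λ per batch (rather than one per example) keeps the permutation-based mixing to two vectorized operations.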

Forward (identical to the previous ViT task)

patches = extract_patches(x_mixed)             # (N, num_patches, C·P²)
patches = patches @ w_proj                     # (N, num_patches, d)
seq = cat([cls_token.expand(N, -1, -1), patches], dim=1) + pos_embed   # (N, T, d), T = num_patches + 1
for each block:
    norm1  = layer_norm(seq)
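    Q,K,V  = norm1 @ w_q, norm1 @ w_k, norm1 @ w_v   # split into heads of size d_head
    attn   = softmax(Q @ K.transpose(-2, -1) / sqrt(d_head))   # softmax over the last dim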
    seq    = seq + (attn @ V) @ w_o            # residual 1
    norm2  = layer_norm(seq)
    seq    = seq + gelu(norm2 @ w_mlp1) @ w_mlp2   # residual 2
cls_repr = seq[:, 0, :]
logits   = cls_repr @ w_head
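A runnable PyTorch sketch of the same forward pass. It assumes details the pseudocode leaves open: multi-head attention with n_heads heads of size d_head = d / n_heads, layer norm without learnable scale or shift, exact (erf) GELU, no bias terms, and a hypothetical params dict layout. It omits the activation cache a full solution keeps for the manual backward.

```python
import math

import torch
import torch.nn.functional as F


def extract_patches(x, p):
    """(N, C, H, W) -> (N, num_patches, C*p*p) using non-overlapping p x p patches."""
    n, c, h, w = x.shape
    x = x.reshape(n, c, h // p, p, w // p, p)  # split H and W into (grid, patch)
    x = x.permute(0, 2, 4, 1, 3, 5)            # (N, H/p, W/p, C, p, p)
    return x.reshape(n, (h // p) * (w // p), c * p * p)


def vit_forward(x, params, p, n_heads, eps=1e-5):
    """Logits for a batch; `params` uses a hypothetical dict layout."""
    n = x.shape[0]
    patches = extract_patches(x, p) @ params["w_proj"]            # (N, num_patches, d)
    cls = params["cls_token"].expand(n, -1, -1)                    # (N, 1, d)
    seq = torch.cat([cls, patches], dim=1) + params["pos_embed"]  # (N, T, d)
    d = seq.shape[-1]
    d_head = d // n_heads

    def split_heads(t):  # (N, T, d) -> (N, heads, T, d_head)
        return t.reshape(n, -1, n_heads, d_head).transpose(1, 2)

    for blk in params["blocks"]:
        norm1 = F.layer_norm(seq, (d,), eps=eps)  # no learnable affine (assumption)
        q = split_heads(norm1 @ blk["w_q"])
        k = split_heads(norm1 @ blk["w_k"])
        v = split_heads(norm1 @ blk["w_v"])
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_head), dim=-1)
        ctx = (attn @ v).transpose(1, 2).reshape(n, -1, d)         # merge heads back to (N, T, d)
        seq = seq + ctx @ blk["w_o"]                               # residual 1
        norm2 = F.layer_norm(seq, (d,), eps=eps)
        seq = seq + F.gelu(norm2 @ blk["w_mlp1"]) @ blk["w_mlp2"]  # residual 2
    return seq[:, 0, :] @ params["w_head"]                         # logits from the CLS token


# Usage with small random weights
torch.manual_seed(0)
N, C, H, W, P, D, HEADS, MLP, CLASSES = 2, 3, 8, 8, 4, 16, 2, 32, 10
num_patches = (H // P) * (W // P)
params = {
    "w_proj": torch.randn(C * P * P, D) * 0.02,
    "cls_token": torch.zeros(1, 1, D),
    "pos_embed": torch.randn(1, num_patches + 1, D) * 0.02,
    "blocks": [
        {"w_q": torch.randn(D, D) * 0.02, "w_k": torch.randn(D, D) * 0.02,
         "w_v": torch.randn(D, D) * 0.02, "w_o": torch.randn(D, D) * 0.02,
         "w_mlp1": torch.randn(D, MLP) * 0.02, "w_mlp2": torch.randn(MLP, D) * 0.02}
        for _ in range(2)
    ],
    "w_head": torch.randn(D, CLASSES) * 0.02,
}
logits = vit_forward(torch.randn(N, C, H, W), params, p=P, n_heads=HEADS)
print(logits.shape)  # torch.Size([2, 10])
```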

Backward

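The only change from hard-label CE is the target: y_mixed replaces the one-hot label. Because each row of y_mixed still sums to 1, the gradient at the logits keeps the familiar form: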

dlogits = (softmax(logits) - y_mixed) / N
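A quick numerical check of this formula, assuming arbitrary soft targets whose rows sum to 1. Autograd appears here only to verify the closed form; it is not part of the manual backward the task requires.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, K = 4, 5
logits = torch.randn(N, K, requires_grad=True)
y_mixed = torch.softmax(torch.randn(N, K), dim=-1)  # random soft targets, rows sum to 1

loss = -(y_mixed * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
(autograd_grad,) = torch.autograd.grad(loss, logits)

# Closed-form gradient from the text
with torch.no_grad():
    dlogits = (torch.softmax(logits, dim=-1) - y_mixed) / N

match = torch.allclose(autograd_grad, dlogits, atol=1e-6)
print(match)  # True
```

The identity relies on the row sums: d(loss)/d(logit_j) = (softmax_j · Σ_k y_k - y_j) / N, which reduces to (softmax_j - y_j) / N only when Σ_k y_k = 1.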

From dlogits, the backward chain through all ViT weights proceeds exactly as before.
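As an illustration of the first link in that chain, here is a sketch of the classifier-head backward: dlogits flows into w_head and into the CLS position of the final sequence only. Shapes are arbitrary, and autograd is used purely as a check.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, T, d, K = 3, 5, 8, 4
seq = torch.randn(N, T, d, requires_grad=True)  # final encoder output
w_head = torch.randn(d, K, requires_grad=True)
y_mixed = torch.softmax(torch.randn(N, K), dim=-1)

# Forward: logits from the CLS position only
cls_repr = seq[:, 0, :]
logits = cls_repr @ w_head
loss = -(y_mixed * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()

# Manual backward for the head
with torch.no_grad():
    dlogits = (torch.softmax(logits, dim=-1) - y_mixed) / N  # (N, K)
    dw_head = cls_repr.T @ dlogits                          # (d, K)
    dcls = dlogits @ w_head.T                               # (N, d)
    dseq = torch.zeros_like(seq)
    dseq[:, 0, :] = dcls  # only the CLS token receives gradient from the head

# Autograd used purely as a check
g_seq, g_head = torch.autograd.grad(loss, (seq, w_head))
print(torch.allclose(dseq, g_seq, atol=1e-6), torch.allclose(dw_head, g_head, atol=1e-6))  # True True
```

From dseq, the remaining steps walk back through each block's residual connections, MLP, attention, and layer norms in reverse order.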

Output

Returns a single flat tensor of all updated weights:

[w_proj_flat, cls_token_flat, pos_embed_flat, blocks_weights_flat, w_head_flat]
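A minimal sketch of the SGD update (step 8) together with this flattening. The order of tensors inside blocks_weights_flat is not specified here, so the per-block order in BLOCK_KEYS (w_q, w_k, w_v, w_o, w_mlp1, w_mlp2) and the helper names sgd_and_flatten and make_params are assumptions; a real solution must match the expected order.

```python
import torch

# Assumed order of weights inside each block (hypothetical)
BLOCK_KEYS = ["w_q", "w_k", "w_v", "w_o", "w_mlp1", "w_mlp2"]


def sgd_and_flatten(params, grads, lr):
    """Plain SGD (w <- w - lr * g), then concatenate all updated tensors into one 1-D tensor."""
    flat = []
    for name in ["w_proj", "cls_token", "pos_embed"]:
        flat.append((params[name] - lr * grads[name]).reshape(-1))
    for blk, gblk in zip(params["blocks"], grads["blocks"]):
        for key in BLOCK_KEYS:
            flat.append((blk[key] - lr * gblk[key]).reshape(-1))
    flat.append((params["w_head"] - lr * grads["w_head"]).reshape(-1))
    return torch.cat(flat)


def make_params(fill, d=4, k=3, mlp=8, num_patches=4, patch_dim=12):
    """Hypothetical tiny parameter set where every entry equals `fill`."""
    f = lambda *shape: torch.full(shape, fill)
    return {
        "w_proj": f(patch_dim, d), "cls_token": f(1, 1, d), "pos_embed": f(1, num_patches + 1, d),
        "blocks": [{"w_q": f(d, d), "w_k": f(d, d), "w_v": f(d, d), "w_o": f(d, d),
                    "w_mlp1": f(d, mlp), "w_mlp2": f(mlp, d)}],
        "w_head": f(d, k),
    }


out = sgd_and_flatten(make_params(1.0), make_params(1.0), lr=0.1)
print(out.shape)  # torch.Size([212])
```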

Implementation constraints

  • No loss.backward() — implement the backward pass by hand.
  • Use torch._standard_gamma with a seeded torch.Generator for Beta sampling (λ = g1 / (g1 + g2) with g1, g2 ~ Gamma(α, 1)), or torch.distributions.Beta(alpha, alpha).sample() (seeding the global RNG first).
  • Cache all intermediate activations during the forward pass; the backward needs them.
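A sketch of the first sampling option: a Beta(α, α) draw built from two Gamma(α, 1) draws as λ = g1 / (g1 + g2), using a seeded generator so the global RNG is untouched. torch._standard_gamma is a private PyTorch op whose signature could change between releases; sample_beta_lambda is a hypothetical helper name, and whether the expected solution seeds exactly this way is an assumption.

```python
import torch


def sample_beta_lambda(alpha, seed):
    """λ ~ Beta(alpha, alpha) via the Gamma ratio g1 / (g1 + g2)."""
    gen = torch.Generator().manual_seed(seed)
    conc = torch.full((2,), float(alpha))                # concentration for two Gamma(alpha, 1) draws
    g = torch._standard_gamma(conc, generator=gen)      # private op; uses only `gen`
    return (g[0] / (g[0] + g[1])).item()


lam = sample_beta_lambda(0.4, seed=42)
```

Because the generator is local, repeated calls with the same seed return the same λ without affecting other random state.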

References

  • Zhang et al., “mixup: Beyond Empirical Risk Minimization”, ICLR 2018.
  • Dosovitskiy et al., “An Image Is Worth 16x16 Words”, ICLR 2021.
