
Test-Time Augmentation Aggregation

Why this matters

A single forward pass at test time is one sample from your model’s error distribution. Test-Time Augmentation (TTA) averages predictions over several augmented versions of the same input, giving a smoother, lower-variance estimate.

Concretely:

  1. Take input x.
  2. Apply K random augmentations: x_1, x_2, ..., x_K.
  3. Run the model on each: y_1, y_2, ..., y_K.
  4. Average them: y_final = mean(y_i).
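The four steps above can be sketched in JAX as follows. This is a minimal illustration, not a library API. The helper `tta_predict`, the Gaussian-noise augmentation, and the default `sigma` are assumptions chosen for the example. It uses `jax.vmap` over the split keys, so the K augmented views run as one batched computation instead of a Python loop.

```python
import jax
import jax.numpy as jnp


def tta_predict(apply_fn, x, key, num_augs=8, sigma=0.01):
    """Noise TTA: average apply_fn over num_augs noisy copies of x.

    apply_fn, sigma and the Gaussian-noise augmentation are illustrative
    assumptions; any per-key augmentation could be swapped in.
    """
    keys = jax.random.split(key, num_augs)              # one key per augmentation

    def one_view(k):
        x_aug = x + sigma * jax.random.normal(k, x.shape)  # step 2: augment
        return apply_fn(x_aug)                           # step 3: predict

    preds = jax.vmap(one_view)(keys)                     # shape (num_augs, ...)
    return jnp.mean(preds, axis=0)                       # step 4: average over views


# Usage with a toy linear "model" (hypothetical weights).
w = jnp.array([[1.0], [-2.0], [0.5]])
x = jnp.ones((4, 3))
y = tta_predict(lambda v: v @ w, x, jax.random.PRNGKey(0))
print(y.shape)  # (4, 1)
```

With `sigma=0.0` every view is identical, so the result reduces to one plain forward pass, which makes a handy sanity check.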

Variants:

  • Crop TTA (vision): five-crop (center + four corners) or ten-crop (those five crops plus their horizontal flips).
  • Flip TTA (vision): horizontal flip + original.
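  • Noise TTA (any continuous-valued input): add small Gaussian noise to inputs.
  • Multi-scale TTA: resize the input to several scales, predict at each, resize dense outputs (e.g., segmentation maps) back to the original resolution, then average.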
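A minimal sketch of flip TTA for classification, assuming NHWC image batches (Flax's default layout) and an `apply_fn` that maps images to per-image logits. The helper name `flip_tta` and the toy model are hypothetical. The original and flipped images are concatenated so both views go through the model in a single batched call.

```python
import jax.numpy as jnp


def flip_tta(apply_fn, images):
    """Average logits over the original and horizontally flipped NHWC images."""
    flipped = jnp.flip(images, axis=2)                    # axis 2 = width in NHWC
    both = jnp.concatenate([images, flipped], axis=0)     # (2N, H, W, C)
    logits = apply_fn(both)                               # one forward pass for both views
    orig_logits, flip_logits = jnp.split(logits, 2, axis=0)
    return (orig_logits + flip_logits) / 2.0


# Toy, deliberately non-flip-invariant "model": reads the top-left pixel.
toy_model = lambda imgs: imgs[:, 0, 0, :]                 # (N, C) "logits"
images = jnp.arange(6.0).reshape(1, 2, 3, 1)              # row 0 is [0, 1, 2]
# Averages top-left (0) and, via the flip, top-right (2).
print(flip_tta(toy_model, images))                        # [[1.]]
```

For dense outputs such as segmentation maps, the flipped half of the predictions would have to be flipped back along the width axis before averaging. The classification version skips that step because logits have no spatial axis.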

Empirical effect: a 0.5-2 % gain on most ImageNet classifiers, with bigger gains for OOD inputs and small validation sets. It needs no retraining, and its only cost is K forward passes per input instead of one, so it’s standard in Kaggle solutions and many production deployment pipelines.

Why averaging logits beats voting

Three options for combining K predictions:

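  1. Average logits (most popular): smooth, and keeps each view’s confidence.
  2. Average softmax probabilities: similar, but behaves differently near saturation. Each view’s probabilities are bounded in [0, 1], so a single overconfident view cannot dominate the mean the way one very large logit can.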
  3. Majority vote of argmax: throws away confidence info.

Averaging logits is the default unless you have a specific reason to do otherwise.
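To make the difference concrete, here is a small sketch that applies all three rules to the same hypothetical per-view logits (K = 3 views, 2 classes). The `aggregate` helper and the logit values are illustrative assumptions. Two views mildly prefer class 0 while the third is extremely confident in class 1, so the rules disagree: averaging logits lets the one saturated view win, averaging probabilities caps its influence, and the vote ignores confidence entirely.

```python
import jax
import jax.numpy as jnp


def aggregate(logits, mode):
    """Combine per-view logits of shape (K, num_classes) into one class id."""
    if mode == "mean_logits":
        return int(jnp.argmax(jnp.mean(logits, axis=0)))
    if mode == "mean_probs":
        probs = jax.nn.softmax(logits, axis=-1)            # per-view probabilities
        return int(jnp.argmax(jnp.mean(probs, axis=0)))
    if mode == "vote":
        votes = jnp.argmax(logits, axis=-1)                # one class id per view
        counts = jnp.bincount(votes, length=logits.shape[-1])
        return int(jnp.argmax(counts))                     # ties go to the lower id
    raise ValueError(f"unknown mode: {mode}")


# Hypothetical logits for K=3 views of one input, 2 classes.
logits = jnp.array([[2.0, 0.0],
                    [2.0, 0.0],
                    [0.0, 10.0]])   # third view is very confident in class 1
for mode in ("mean_logits", "mean_probs", "vote"):
    print(mode, aggregate(logits, mode))
# mean_logits 1
# mean_probs 0
# vote 0
```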

RNG: jax.random.split

Each augmentation needs its own random source. The JAX-canonical way:

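import jax
import jax.numpy as jnp
import flax.linen as nn

seed, num_augs = 0, 4
x = jnp.ones((2, 3))                       # example input, shape (N, D)
init_key, base_key = jax.random.split(jax.random.PRNGKey(seed))
model = nn.Dense(1)
params = model.init(init_key, x)["params"]
aug_keys = jax.random.split(base_key, num_augs)
# aug_keys: shape (num_augs, 2), one PRNGKey per aug.
preds = []
for i in range(num_augs):
    x_aug = x + 0.01 * jax.random.normal(aug_keys[i], x.shape)
    preds.append(model.apply({"params": params}, x_aug))
y_final = jnp.mean(jnp.stack(preds), axis=0)  # average over the K augs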

If you skip the split and pass the same key to every aug, you do NOT get different randomness: JAX’s RNG is a pure function of the key, so every call returns identical noise, all K augmentations are the same, and averaging them gains nothing.
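A quick check of this behavior, as a sketch with an arbitrary key and shape: drawing noise three times from one key gives identical arrays, while keys produced by `jax.random.split` give different ones.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
shape = (2, 3)

# Reusing one key: every "augmentation" draws exactly the same noise.
same_noise = [jax.random.normal(key, shape) for _ in range(3)]

# Splitting: each augmentation gets its own key, hence different noise.
aug_keys = jax.random.split(key, 3)
split_noise = [jax.random.normal(k, shape) for k in aug_keys]

print(all(bool(jnp.array_equal(same_noise[0], s)) for s in same_noise))  # True
print(bool(jnp.array_equal(split_noise[0], split_noise[1])))             # False
```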

When TTA helps vs. hurts

Helps:

  • High-variance models (small training sets).
  • OOD inputs.
  • Inference where test-time compute is cheap.

Doesn’t help (much):

  • Already-overfit models with very confident predictions.
  • Tasks where the augmentation breaks semantics (e.g., rotating MNIST digits by 180° turns a ‘6’ into a ‘9’).

Hurts:

  • Real-time systems where K extra passes blow the latency budget.
  • When the model wasn’t trained with that augmentation type — the augmentations may push it OOD.

Problem

Build nn.Dense(1). Initialize it from PRNGKey(seed), splitting off a separate key (aug_rng) for the augmentations. For each of the num_augs aug keys from jax.random.split(aug_rng, num_augs):

  1. Sample noise: 0.01 * jax.random.normal(key, x.shape).
  2. x_aug = x + noise.
  3. Compute pred_i = model.apply({"params": params}, x_aug).reshape(-1).mean().

Average all K predictions and return the result as a 1-D array of shape (1,).

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D).
  • num_augs: float (cast to int).

Output: 1-D array of shape (1,) containing the mean of the per-aug means.

Hints

flax inference test-time-augmentation
