Test-Time Augmentation Aggregation
Why this matters
A single forward pass at test time is one sample from your model’s error distribution. Test-Time Augmentation (TTA) averages predictions over several augmented versions of the same input, giving a smoother, lower-variance estimate.
Concretely:
- Take input x.
- Apply K random augmentations: x_1, x_2, ..., x_K.
- Run the model on each: y_1, y_2, ..., y_K.
- Average them: y_final = mean(y_i).
Variants:
- Crop TTA (vision): five-crop or ten-crop — center + four corners + flip-equivalents.
- Flip TTA (vision): horizontal flip + original.
- Noise TTA (everything): add small Gaussian noise to inputs.
- Multi-scale TTA: resize to different scales, predict, rescale, average.
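As a minimal sketch of the flip variant: the code below assumes images in (N, H, W, C) layout and uses a hypothetical linear classifier, `toy_logits`, as a stand-in for a real model. Neither assumption comes from the text above. It scores the original and mirrored batches and averages their logits.

```python
import jax
import jax.numpy as jnp


def toy_logits(weights, images):
    """Hypothetical stand-in classifier: flatten each image, apply one linear map."""
    flat = images.reshape(images.shape[0], -1)
    return flat @ weights


def flip_tta_logits(weights, images):
    """Average logits over the original batch and its horizontal mirror."""
    flipped = jnp.flip(images, axis=2)       # axis 2 is width for (N, H, W, C)
    views = jnp.stack([images, flipped])     # (2, N, H, W, C)
    logits = jax.vmap(lambda v: toy_logits(weights, v))(views)  # (2, N, classes)
    return logits.mean(axis=0)               # (N, classes)


# Illustrative random data and weights, not a trained model.
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
images = jax.random.normal(key_x, (4, 8, 8, 3))
weights = jax.random.normal(key_w, (8 * 8 * 3, 10))
tta = flip_tta_logits(weights, images)
```

Because both views are averaged, the result is the same whether the input batch or its mirror image is passed in. This is the invariance that flip TTA enforces at prediction time.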
Empirical effect: typically a 0.5-2% accuracy gain on most ImageNet
classifiers, with bigger gains for OOD inputs and small validation sets.
TTA needs no retraining; its only cost is K forward passes per input
instead of one, so it’s standard in Kaggle solutions and many
production deployment pipelines.
Why averaging logits beats voting
Three options for combining K predictions:
- Average logits (most popular): smooth, calibrated.
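- Average softmax probabilities: similar, but because probabilities are bounded in [0, 1], a single saturated (very confident) prediction pulls the average less than it does in logit space.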
- Majority vote of argmax: throws away confidence info.
Average logits is the default unless you have a specific reason otherwise.
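To make the difference between the three rules concrete, here is a small sketch with made-up logits for K = 3 views and 2 classes. The values are purely illustrative. Two views lean weakly toward class 0 and one leans strongly toward class 1.

```python
import jax
import jax.numpy as jnp

# Logits from K = 3 augmented views of one input, 2 classes (illustrative values).
logits = jnp.array([[0.2, 0.0],
                    [0.2, 0.0],
                    [-3.0, 0.0]])

# 1) Average logits, then take the argmax.
mean_logit_class = int(jnp.argmax(logits.mean(axis=0)))

# 2) Average softmax probabilities, then take the argmax.
probs = jax.nn.softmax(logits, axis=-1)
mean_prob_class = int(jnp.argmax(probs.mean(axis=0)))

# 3) Majority vote over per-view argmax; per-view confidence is discarded.
votes = jnp.argmax(logits, axis=-1)
vote_class = int(jnp.argmax(jnp.bincount(votes, length=logits.shape[1])))

print(mean_logit_class, mean_prob_class, vote_class)  # prints: 1 1 0
```

Both averaging rules pick class 1 because the one confident view outweighs the two near-ties. The vote picks class 0 because it counts each view equally, however unsure it was.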
RNG: jax.random.split
Each augmentation needs its own random source. The JAX-canonical way:
import jax

base_key = jax.random.PRNGKey(seed)
aug_keys = jax.random.split(base_key, num_augs)
# aug_keys: shape (num_augs, 2), one PRNGKey per aug.
for i in range(num_augs):
    noise = 0.01 * jax.random.normal(aug_keys[i], x.shape)
    x_aug = x + noise
    y_i = model.apply(params, x_aug)
Without splitting, every augmentation would reuse the same key. JAX random functions are deterministic given a key, so all K noise samples would be identical and averaging over them would gain nothing.
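The effect of reusing a key versus splitting it can be checked directly. This is a small sketch with an arbitrary seed, shape, and number of augmentations, all chosen only for illustration.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.PRNGKey(0)
shape = (4, 3)
num_augs = 3

# Reusing one key: every draw returns the same array.
reused = jnp.stack([jax.random.normal(base_key, shape) for _ in range(num_augs)])

# Splitting first: each subkey yields its own draw.
aug_keys = jax.random.split(base_key, num_augs)
split = jnp.stack([jax.random.normal(k, shape) for k in aug_keys])

same_when_reused = bool(jnp.all(reused[0] == reused[1]))
same_when_split = bool(jnp.all(split[0] == split[1]))
```

Here `same_when_reused` is true and `same_when_split` is false. This is why the loop indexes into `aug_keys` rather than passing `base_key` each time.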
When TTA helps vs. hurts
Helps:
- High-variance models (small training sets).
- OOD inputs.
- Inference where test-time compute is cheap.
Doesn’t help (much):
- Already-overfit models with very confident predictions.
- Tasks where the augmentation breaks semantics (e.g., flipping or rotating digits in MNIST: a 180° rotation turns ‘6’ into ‘9’).
Hurts:
- Real-time systems where K extra passes blow the latency budget.
- When the model wasn’t trained with that augmentation type — the augmentations may push it OOD.
Problem
Build nn.Dense(1). Init with PRNGKey(seed) (split off a separate
key for augs). For each of num_augs aug-keys (from
jax.random.split(aug_rng, num_augs)):
- Sample noise: 0.01 * jax.random.normal(key, x.shape).
- x_aug = x + noise.
- Compute pred_i = model.apply({"params": params}, x_aug).reshape(-1).mean().
Average all K predictions and return as 1-D (1,).
Inputs:
- seed: float (cast to int).
- x: 2-D (N, D).
- num_augs: float (cast to int).
Output: 1-D (1,) — [mean_of_per_aug_means].