Implement GroupNorm
Why this matters
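GroupNorm sits between LayerNorm and BatchNorm: split channels into G
groups and normalize within each group. With G=1 it equals LayerNorm
(over channels); with G=C (one channel per group) it equals
InstanceNorm, provided the statistics also span the spatial axes as in the standard image formulation. In the channel-only form used below, a one-channel group has zero variance and normalizes to 0.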
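The G=1 case can be checked numerically. The sketch below is illustrative only: `normalize_groups` and `layer_norm` are hypothetical helper names, the affine parameters are left out, and `eps=1e-5` is an assumed default.

```python
import jax.numpy as jnp

def normalize_groups(x, num_groups, eps=1e-5):
    # Normalization step only (no gamma/beta); statistics are per group.
    *leading, C = x.shape
    x_g = x.reshape(*leading, num_groups, C // num_groups)
    mu = x_g.mean(axis=-1, keepdims=True)
    var = x_g.var(axis=-1, keepdims=True)  # biased variance (ddof=0)
    return ((x_g - mu) / jnp.sqrt(var + eps)).reshape(*leading, C)

def layer_norm(x, eps=1e-5):
    # Plain LayerNorm over the last (channel) axis, without affine parameters.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

x = jnp.array([[1.0, 2.0, 4.0, 8.0], [0.0, -1.0, 3.0, 5.0]])
# With a single group, the group statistics are the LayerNorm statistics.
same = bool(jnp.allclose(normalize_groups(x, num_groups=1), layer_norm(x), atol=1e-6))
```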
GroupNorm became popular when ResNets needed to train at small batch sizes (e.g., on detection/segmentation models). BatchNorm degrades badly with small batches; GroupNorm is batch-independent like LayerNorm but groups channels, which retains useful structure per filter.
Math
For input x with shape (..., C) and G groups (C % G == 0):
- Reshape so the last axis splits into (G, C/G).
- Compute mean and variance across the per-group channel axis.
- Normalize, reshape back to (..., C).
- Apply per-channel γ, β.
leading_dims, C = x.shape[:-1], x.shape[-1]
x_g = x.reshape(*leading_dims, G, C // G)
μ = mean(x_g, axis=-1, keepdims=True)
σ² = mean((x_g - μ)², axis=-1, keepdims=True)
x̂_g = (x_g - μ) / sqrt(σ² + ε)
x̂ = x̂_g.reshape(*leading_dims, C)
y = γ ⊙ x̂ + β
Note: γ, β are PER-CHANNEL (shape (C,)), not per-group. The grouping
only affects WHERE statistics are computed, not the affine transform.
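The formulas above translate almost line for line into JAX. This is a minimal sketch rather than a reference implementation: the function name `group_norm`, its argument order, and the default `eps=1e-5` are assumptions made for illustration.

```python
import jax.numpy as jnp

def group_norm(x, gamma, beta, num_groups, eps=1e-5):
    """Normalize the last axis of x in num_groups contiguous channel groups."""
    *leading, C = x.shape
    assert C % num_groups == 0, "C must be divisible by num_groups"
    x_g = x.reshape(*leading, num_groups, C // num_groups)
    mu = x_g.mean(axis=-1, keepdims=True)
    var = ((x_g - mu) ** 2).mean(axis=-1, keepdims=True)  # biased variance
    x_hat = ((x_g - mu) / jnp.sqrt(var + eps)).reshape(*leading, C)
    return gamma * x_hat + beta  # gamma and beta have shape (C,)

x = jnp.arange(24.0).reshape(2, 3, 4)  # leading dims (2, 3), C = 4
y = group_norm(x, jnp.ones(4), jnp.zeros(4), num_groups=2)
```

Because γ and β broadcast against the last axis only, the same function works for any number of leading dimensions, including none.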
Worked example
Input x = [1, 2, 3, 4, 5, 6] (shape (6,)), G=2:
x_g = [[1, 2, 3], [4, 5, 6]] # shape (2, 3)
means: [2, 5]
vars: [(1+0+1)/3, (1+0+1)/3] = [2/3, 2/3]
x_hat (per group): each group becomes [-1.225, 0, 1.225] approximately
flatten: [-1.225, 0, 1.225, -1.225, 0, 1.225]
γ=1, β=0 → identity, so y == x_hat.
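The arithmetic above can be reproduced in a few lines of JAX. The value `eps = 1e-5` is an assumption (the example does not state one); it is small enough not to change the rounded result.

```python
import jax.numpy as jnp

eps = 1e-5  # assumed; not specified in the worked example
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

x_g = x.reshape(2, 3)                      # [[1, 2, 3], [4, 5, 6]]
mu = x_g.mean(axis=-1, keepdims=True)      # [[2], [5]]
var = x_g.var(axis=-1, keepdims=True)      # [[2/3], [2/3]] (biased variance)
x_hat = ((x_g - mu) / jnp.sqrt(var + eps)).reshape(-1)
# x_hat is approximately [-1.225, 0, 1.225, -1.225, 0, 1.225]
```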
When to choose GroupNorm
- Small batch training: BatchNorm degrades; GroupNorm is independent of batch size.
- Fine-tuning under varying batch sizes: BatchNorm's running statistics can mismatch between training and evaluation; GroupNorm keeps no running statistics.
- Style transfer / generation: per-instance normalization tends to work better than per-batch.
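GroupNorm is a common choice in detection and segmentation models, which usually train with only a few images per device (the original GroupNorm paper evaluated it with Mask R-CNN on COCO).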
Common pitfalls
- C not divisible by G: assert C % G == 0 or pick groups compatible with channels (e.g., G=32 for C=64, 128, 256, ...).
- Reshape order: split into (G, C/G), not (C/G, G). The first C/G channels go in group 0, the next C/G in group 1, etc.
- Per-channel γ, β: shape (C,), applied AFTER the reshape back. Applying them before the reshape back only works if γ and β are themselves reshaped to (G, C/G); the result is the same, but it is unconventional.
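The first two pitfalls are easy to see on a tiny input. In this sketch `split_groups` is a hypothetical helper, not part of any library; it only performs the divisibility check and the reshape.

```python
import jax.numpy as jnp

def split_groups(x, num_groups):
    """Split the last axis into (G, C // G); hypothetical helper name."""
    C = x.shape[-1]
    assert C % num_groups == 0, f"C={C} is not divisible by G={num_groups}"
    return x.reshape(*x.shape[:-1], num_groups, C // num_groups)

x = jnp.arange(6)              # channel indices 0..5, so C = 6
right = split_groups(x, 2)     # rows are groups: [[0, 1, 2], [3, 4, 5]]
wrong = x.reshape(6 // 2, 2)   # (C/G, G) layout: [[0, 1], [2, 3], [4, 5]]
```

Reducing over the last axis of `wrong` would silently compute statistics for three groups of two channels instead of two groups of three, so the bug produces no error, only different numbers.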
Problem
Implement MyGroupNorm(num_groups, eps):
- Read C = x.shape[-1]. Assert C % num_groups == 0.
- gamma = self.param("gamma", ones, (C,)), beta = self.param("beta", zeros, (C,)).
- Reshape x to (*leading, G, C/G).
- Mean and variance over the last axis (C/G).
- Normalize, reshape back.
- Apply gamma * x_hat + beta.
Tests pass a 1-D x (no leading dims), so the reshape becomes x.reshape(G, C // G).
Inputs:
- seed: float (cast to int).
- x: 1-D JAX array.
- num_groups: float (cast to int).
- eps: float.
Output: 1-D array (flattened).