Implement Group Norm
Implement Group Normalization (Wu & He, 2018 — “Group Normalization”).
Group Normalization was designed to address BatchNorm’s sensitivity to small batch sizes. Instead of normalizing across the batch dimension, it normalizes across groups of channels within each sample, so its statistics do not depend on the batch size.
Math:
$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$$
where $\mu$ and $\sigma^2$ are the mean and variance of $x$, computed separately for each sample over all channels in a group and all spatial positions ($H \times W$).
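Written out per sample $n$ and group $g$ (this index notation is added here for clarity and is not part of the original statement), the statistics average over the $m = \frac{C}{G}HW$ entries whose channel index $c$ lies in $\mathcal{G}_g$, the set of $C/G$ consecutive channels assigned to group $g$:

$$\mu_{n,g} = \frac{1}{m}\sum_{c \in \mathcal{G}_g}\sum_{h=1}^{H}\sum_{w=1}^{W} x_{n,c,h,w}, \qquad \sigma^2_{n,g} = \frac{1}{m}\sum_{c \in \mathcal{G}_g}\sum_{h=1}^{H}\sum_{w=1}^{W} \left(x_{n,c,h,w} - \mu_{n,g}\right)^2$$

The denominator is $m$, not $m - 1$ (biased variance). As a small worked example, a group holding the four values $1, 2, 3, 4$ (say $C/G = 1$, $H = W = 2$) gives, with $\epsilon$ neglected:

$$\mu = \frac{1+2+3+4}{4} = 2.5, \qquad \sigma^2 = \frac{(-1.5)^2 + (-0.5)^2 + 0.5^2 + 1.5^2}{4} = 1.25, \qquad \hat{x} \approx (-1.342,\ -0.447,\ 0.447,\ 1.342)$$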
Algorithm:
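- Reshape `x` from `(N, C, H, W)` to `(N, G, C/G, H, W)`, where `G = num_groups`.
- Compute the mean and variance over the last three axes (the group’s channels, `H`, and `W`). Use `dim=(2, 3, 4), keepdim=True` in PyTorch, or `axis=(2, 3, 4), keepdims=True` in JAX. Use the biased (population) variance, which divides by the number of elements rather than by that number minus one: PyTorch’s `var` is unbiased by default, so pass `unbiased=False`, whereas `jnp.var` is already biased (`ddof=0`).
- Normalize: `(x - mean) / sqrt(var + eps)`.
- Reshape back to `(N, C, H, W)`.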
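A minimal sketch of these four steps, assuming NumPy is available. NumPy’s `mean` and `var` take the same `axis=(2, 3, 4), keepdims=True` arguments mentioned for JAX, and the same body should also work with `jax.numpy` arrays. The function name `group_norm` and the shape and divisibility checks are illustrative choices, not part of the problem statement.

```python
import numpy as np


def group_norm(x: np.ndarray, num_groups: int, eps: float = 1e-5) -> np.ndarray:
    """Normalize an (N, C, H, W) array per sample over channel groups (no affine)."""
    if x.ndim != 4:
        raise ValueError(f"expected shape (N, C, H, W), got {x.shape}")
    n, c, h, w = x.shape
    if c % num_groups != 0:
        raise ValueError(f"num_groups={num_groups} must divide C={c}")

    # Step 1: (N, C, H, W) -> (N, G, C/G, H, W); channel c = g * (C/G) + k
    xg = x.reshape(n, num_groups, c // num_groups, h, w)

    # Step 2: statistics over the group's channels and spatial dims.
    # ndarray.var defaults to ddof=0, i.e. the biased (population) variance.
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)

    # Step 3: normalize; keepdims=True makes mean/var broadcast over each group
    xg_hat = (xg - mean) / np.sqrt(var + eps)

    # Step 4: restore the original (N, C, H, W) layout
    return xg_hat.reshape(n, c, h, w)


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 6, 4, 4))
y = group_norm(x, num_groups=3)
print(y.shape)  # (2, 6, 4, 4)
```

Step 3 relies on broadcasting: because `keepdims=True` leaves `mean` and `var` with shape `(N, G, 1, 1, 1)`, each statistic applies to every element of its own group without explicit tiling.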
Special cases:
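- `num_groups=1` is equivalent to Layer Normalization (all channels and spatial positions are normalized together).
- `num_groups=C` is equivalent to Instance Normalization (each channel’s spatial positions are normalized separately).
- Intermediate values of `num_groups` place GroupNorm between these two extremes: each statistic pools more channels than InstanceNorm but fewer than LayerNorm.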
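Both equivalences can be checked numerically. The NumPy sketch below (assuming NumPy is available) defines a compact, validation-free `group_norm` and a hypothetical helper `normalize_over` that normalizes over an arbitrary set of axes. This helper stands in for Layer and Instance Normalization without affine parameters.

```python
import numpy as np


def group_norm(x, num_groups, eps=1e-5):
    # Compact no-affine GroupNorm: split channels into groups, normalize, restore
    n, c, h, w = x.shape
    xg = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)  # biased variance (ddof=0)
    return ((xg - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)


def normalize_over(x, axes, eps=1e-5):
    # Hypothetical helper: normalize x over the given axes, no affine parameters
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 3, 3))
C = x.shape[1]

# num_groups=1: one group spans all channels -> LayerNorm over (C, H, W)
layer_norm = normalize_over(x, axes=(1, 2, 3))
print(np.allclose(group_norm(x, 1), layer_norm))  # True

# num_groups=C: one channel per group -> InstanceNorm over (H, W)
instance_norm = normalize_over(x, axes=(2, 3))
print(np.allclose(group_norm(x, C), instance_norm))  # True
```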
This version has no learnable affine parameters (γ, β); it performs pure normalization.
Inputs:
- `x`: tensor of shape `(N, C, H, W)`. `num_groups` must divide `C`.
- `num_groups`: int, the number of groups to split the channels into.
- `eps`: float, a small constant for numerical stability (default `1e-5`).
Output: a tensor of the same shape, `(N, C, H, W)`, normalized per group.
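The output contract can be checked directly. This NumPy sketch assumes NumPy is available and defines a compact `group_norm` inline. It confirms three things. First, the shape is preserved. Second, every (sample, group) slice has mean close to 0 and variance close to 1, slightly below 1 because of `eps`. Third, normalizing a sample on its own gives the same result as normalizing it inside a batch, which is the batch-size independence described at the top.

```python
import numpy as np


def group_norm(x, num_groups, eps=1e-5):
    # Compact no-affine GroupNorm: split channels into groups, normalize, restore
    n, c, h, w = x.shape
    xg = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)  # biased variance (ddof=0)
    return ((xg - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)


n, c, h, w = 4, 8, 5, 5
G = 2
rng = np.random.default_rng(1)
x = rng.normal(loc=-2.0, scale=5.0, size=(n, c, h, w))
y = group_norm(x, G)

# Output has the same shape as the input
print(y.shape == x.shape)  # True

# Each (sample, group) slice is approximately zero-mean and unit-variance
yg = y.reshape(n, G, c // G, h, w)
print(np.allclose(yg.mean(axis=(2, 3, 4)), 0.0, atol=1e-6))  # True
print(np.allclose(yg.var(axis=(2, 3, 4)), 1.0, atol=1e-3))   # True

# Statistics never mix samples, so a lone sample normalizes identically
print(np.allclose(group_norm(x[:1], G)[0], y[0]))  # True
```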