
Implement Group Norm

Implement Group Normalization (Wu & He, 2018 — “Group Normalization”).

Group Normalization was designed to fix BatchNorm’s sensitivity to small batch sizes. Instead of normalizing across the batch dimension, it normalizes across groups of channels within each sample — making statistics independent of batch size.

Math:

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

where $\mu$ and $\sigma^2$ are computed separately for each sample and each group, over that group’s channels and spatial dims.
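Written out per sample $n$ and group $g$, the statistics can be sketched as below. The index set $S_{n,g}$ (all positions $(c, h, w)$ whose channel $c$ belongs to group $g$ of sample $n$) and the count $m$ are notation introduced here for illustration, not symbols from the problem statement.

$$\mu_{n,g} = \frac{1}{m} \sum_{(c,h,w) \in S_{n,g}} x_{n,c,h,w}, \qquad \sigma^2_{n,g} = \frac{1}{m} \sum_{(c,h,w) \in S_{n,g}} \left(x_{n,c,h,w} - \mu_{n,g}\right)^2, \qquad m = \frac{C}{G} \cdot H \cdot W$$

Note the $1/m$ factor in the variance: this is the biased (population) variance, not the $1/(m-1)$ sample variance.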

Algorithm:

  1. Reshape x from (N, C, H, W) to (N, G, C/G, H, W) where G = num_groups.
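  2. Compute mean and variance over the last 3 axes (the group’s channels + H + W). Use dim=(2, 3, 4), keepdim=True in PyTorch, or axis=(2, 3, 4), keepdims=True in JAX. Use the biased (population) variance: torch.var defaults to the unbiased estimator, so pass unbiased=False, while jnp.var already defaults to ddof=0.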
  3. Normalize: (x - mean) / sqrt(var + eps).
  4. Reshape back to (N, C, H, W).
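
The four steps above can be sketched as follows. This is a minimal illustration in NumPy rather than a reference solution; NumPy is an assumption chosen because its axis and keepdims arguments match the JAX calls named in step 2. The function name group_norm and the random test input are likewise illustrative.

```python
import numpy as np


def group_norm(x, num_groups, eps=1e-5):
    """Group-normalize x of shape (N, C, H, W), with no learnable affine."""
    n, c, h, w = x.shape
    # Step 1: split the channel axis into (G, C/G).
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    # Step 2: per-sample, per-group statistics over (C/G, H, W).
    # np.var uses ddof=0 by default, i.e. the biased variance.
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)
    # Step 3: normalize each group with its own statistics.
    normalized = (grouped - mean) / np.sqrt(var + eps)
    # Step 4: merge (G, C/G) back into the original channel axis.
    return normalized.reshape(n, c, h, w)


# Illustrative input: 2 samples, 6 channels, 4x4 spatial, non-zero mean.
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 6, 4, 4))
y = group_norm(x, num_groups=3)
```

Because consecutive channels are grouped together by the reshape, channels 0 and 1 form the first group here, channels 2 and 3 the second, and so on. Each group of y should have mean close to 0 and variance close to 1 (slightly below 1 because of eps).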

Special cases:

  • num_groups=1 is equivalent to Layer Normalization over (C, H, W): each sample is normalized across all of its channels and spatial positions together.
  • num_groups=C is equivalent to Instance Normalization: each channel of each sample is normalized over its own spatial dims.
  • For 1 < num_groups < C, GroupNorm sits between these two extremes.
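
The two equivalences can be checked numerically. The sketch below is an illustration in NumPy (an assumption, chosen for its JAX-like axis arguments); layer_norm_chw and instance_norm are hypothetical reference helpers written for this comparison, each without affine parameters, and group_norm is a compact version of the algorithm described above.

```python
import numpy as np


def group_norm(x, num_groups, eps=1e-5):
    """Affine-free group normalization for (N, C, H, W) input."""
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)  # biased variance (ddof=0)
    return ((g - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)


def layer_norm_chw(x, eps=1e-5):
    """Normalize each sample over all of (C, H, W)."""
    mean = x.mean(axis=(1, 2, 3), keepdims=True)
    var = x.var(axis=(1, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def instance_norm(x, eps=1e-5):
    """Normalize each channel of each sample over (H, W)."""
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(1)
x = rng.normal(size=(2, 4, 3, 3))  # C = 4

# One group spanning all channels vs. one group per channel.
matches_layer_norm = np.allclose(group_norm(x, 1), layer_norm_chw(x))
matches_instance_norm = np.allclose(group_norm(x, 4), instance_norm(x))
```

The comparison holds only because none of the three functions applies a learnable scale or shift; with affine parameters the layers differ in how those parameters are shaped.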

No learnable affine (γ, β) for this version — pure normalization.

Inputs:

  • x: tensor of shape (N, C, H, W). num_groups must divide C.
  • num_groups: int — number of groups to split channels into.
  • eps: float — small constant for numerical stability (default 1e-5).

Output: same-shape tensor (N, C, H, W) — per-group normalized.
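
The input contract above can be enforced before any reshaping. The helper below is a hypothetical sketch (its name and error messages are not part of the problem statement), written in NumPy on the assumption that the tensor exposes ndim and shape as NumPy and JAX arrays do.

```python
import numpy as np


def validate_group_norm_inputs(x, num_groups):
    """Hypothetical helper: raise ValueError if the stated contract is violated."""
    if x.ndim != 4:
        raise ValueError(f"expected a 4-D (N, C, H, W) tensor, got {x.ndim}-D")
    channels = x.shape[1]
    # Checking num_groups < 1 first avoids a modulo-by-zero for num_groups=0.
    if num_groups < 1 or channels % num_groups != 0:
        raise ValueError(
            f"num_groups={num_groups} must be a positive divisor of C={channels}"
        )


x = np.zeros((2, 6, 4, 4))
validate_group_norm_inputs(x, num_groups=3)  # 3 divides 6, so no error

try:
    validate_group_norm_inputs(x, num_groups=4)  # 4 does not divide 6
    rejected = False
except ValueError:
    rejected = True
```

Without such a check, a reshape to (N, G, C/G, H, W) with integer division would fail with a less descriptive shape error when num_groups does not divide C.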
