Implement GroupNorm

Why this matters

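GroupNorm interpolates between LayerNorm and InstanceNorm: split the channels into G groups and normalize within each group. With G=1 it equals LayerNorm over the channels. With G=C (one channel per group), the standard formulation, which also pools statistics over spatial positions, equals InstanceNorm. In the channels-only version implemented here, a one-channel group has zero variance, so G=C normalizes every value to zero.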
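As a quick check of the G=1 case, the sketch below compares a minimal GroupNorm (statistics over the per-group channel axis only, no γ or β) with a hand-written LayerNorm over the last axis. The helper names group_norm and layer_norm are illustrative assumptions, not part of the problem's API.

```python
import jax.numpy as jnp

def group_norm(x, num_groups, eps=1e-5):
    # Illustrative helper: normalize within each group of channels on the last axis.
    *leading, C = x.shape
    x_g = x.reshape(*leading, num_groups, C // num_groups)
    mean = x_g.mean(axis=-1, keepdims=True)
    var = x_g.var(axis=-1, keepdims=True)
    return ((x_g - mean) / jnp.sqrt(var + eps)).reshape(x.shape)

def layer_norm(x, eps=1e-5):
    # Plain LayerNorm over the last axis, without an affine transform.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

x = jnp.arange(24.0).reshape(2, 3, 4) ** 1.5  # leading dims (2, 3), C = 4

# G = 1: one group holding every channel, so the two functions should agree.
same_as_layer_norm = bool(jnp.allclose(group_norm(x, 1), layer_norm(x), atol=1e-4))

# G = C: with no spatial axes to pool over, each group holds a single value,
# so its variance is zero and the normalized output is zero everywhere.
single_channel_groups = group_norm(x, 4)
```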

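GroupNorm became popular when ResNet-style models had to be trained at small batch sizes (e.g., for detection and segmentation, where large inputs limit how many images fit in a batch). BatchNorm degrades badly with small batches because its batch statistics become noisy. GroupNorm is batch-independent like LayerNorm, but it computes statistics per group of channels rather than over all channels at once, so different groups of filters can keep different scales.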

Math

For input x with shape (..., C) and G groups (C % G == 0):

  1. Reshape so the last axis splits into (G, C/G).
  2. Compute mean and variance across the per-group channel axis.
  3. Normalize, reshape back to (..., C).
  4. Apply per-channel γ, β.

In pseudocode:

leading_dims, C = x.shape[:-1], x.shape[-1]
x_g = x.reshape(*leading_dims, G, C // G)
μ   = mean(x_g, axis=-1, keepdims=True)
σ²  = mean((x_g - μ)², axis=-1, keepdims=True)
x̂_g = (x_g - μ) / sqrt(σ² + ε)
x̂   = x̂_g.reshape(*leading_dims, C)
y   = γ ⊙ x̂ + β
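
The pseudocode above translates almost line for line into JAX. The following is a minimal sketch rather than a reference implementation: the function name group_norm, its argument order, and the default eps are assumptions made for illustration.

```python
import jax.numpy as jnp

def group_norm(x, gamma, beta, num_groups, eps=1e-5):
    # Split the last axis into (G, C // G) and normalize inside each group.
    *leading, C = x.shape
    assert C % num_groups == 0, "C must be divisible by num_groups"
    x_g = x.reshape(*leading, num_groups, C // num_groups)
    mu = jnp.mean(x_g, axis=-1, keepdims=True)
    var = jnp.mean((x_g - mu) ** 2, axis=-1, keepdims=True)  # biased variance
    x_hat = ((x_g - mu) / jnp.sqrt(var + eps)).reshape(x.shape)
    # gamma and beta have shape (C,) and broadcast over the leading dims.
    return gamma * x_hat + beta

x = jnp.linspace(-3.0, 5.0, 16).reshape(2, 8)  # batch of 2, C = 8
y = group_norm(x, jnp.ones(8), jnp.zeros(8), num_groups=4)
```

With γ=1 and β=0, every group of the output should have a mean close to zero, which is an easy sanity check on the reshape.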

Note: γ, β are PER-CHANNEL (shape (C,)), not per-group. The grouping only affects WHERE statistics are computed, not the affine transform.

Worked example

Input x = [1, 2, 3, 4, 5, 6] (shape (6,)), G=2:

x_g = [[1, 2, 3], [4, 5, 6]]   # shape (2, 3)
means: [2, 5]
vars:  [(1+0+1)/3, (1+0+1)/3] = [2/3, 2/3]
x_hat (per group): each group becomes [-1.225, 0, 1.225] approximately
flatten: [-1.225, 0, 1.225, -1.225, 0, 1.225]
γ=1, β=0 → identity, so y == x_hat.
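
The arithmetic above can be reproduced with a few lines of JAX. This is only a check of the worked example; the value eps = 1e-5 is an assumption, and it is small enough not to change the rounded result.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
G, eps = 2, 1e-5  # eps is an assumed value, not given in the example

x_g = x.reshape(G, x.shape[0] // G)       # [[1, 2, 3], [4, 5, 6]]
mean = x_g.mean(axis=-1, keepdims=True)   # [[2], [5]]
var = x_g.var(axis=-1, keepdims=True)     # [[2/3], [2/3]]
x_hat = ((x_g - mean) / jnp.sqrt(var + eps)).reshape(-1)
# x_hat is approximately [-1.2247, 0, 1.2247, -1.2247, 0, 1.2247]
```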

When to choose GroupNorm

  • Small batch training: BatchNorm degrades; GroupNorm is independent of batch size.
  • Fine-tuning under varying batch sizes: BatchNorm’s running statistics can mismatch between training and evaluation; GroupNorm keeps no running statistics.
  • Style transfer / generation: per-instance normalization tends to work better than per-batch.

GroupNorm is a common choice in detection and segmentation models (e.g., DETR uses it in its segmentation head).

Common pitfalls

  • C not divisible by G: assert C % G == 0 or pick groups compatible with channels (e.g., G=32 for C=64, 128, 256, ...).
  • Reshape order: split into (G, C/G) not (C/G, G). The first C/G channels go in group 0, next C/G in group 1, etc.
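  • Per-channel γ, β: shape (C,), applied AFTER reshaping back to (..., C). Applying them before the reshape back also works if γ, β are first reshaped to (G, C/G); that is equivalent but unconventional. Applying them before normalization is a bug, because the normalization then largely undoes the scale and shift.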
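
The reshape-order and affine-placement pitfalls are easy to see on a tiny array. The sketch below uses hypothetical variable names and a mean-subtracted stand-in for the normalized groups, since only the layout matters here.

```python
import jax.numpy as jnp

C, G = 6, 2
x = jnp.arange(C, dtype=jnp.float32)

right = x.reshape(G, C // G)  # [[0, 1, 2], [3, 4, 5]]: 2 groups of 3 channels
wrong = x.reshape(C // G, G)  # [[0, 1], [2, 3], [4, 5]]: 3 groups of 2 channels

# Stand-in for the normalized groups (mean-subtracted only).
x_hat_g = right - right.mean(axis=-1, keepdims=True)
gamma = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# Conventional: reshape back to (C,), then apply the per-channel scale.
after = gamma * x_hat_g.reshape(C)
# Equivalent: reshape gamma to (G, C // G) and apply it before flattening.
before = (gamma.reshape(G, C // G) * x_hat_g).reshape(C)
```

Reducing over the last axis of wrong would silently compute statistics for three groups instead of two, which is why the order of the two reshape arguments matters.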

Problem

Implement MyGroupNorm(num_groups, eps):

  1. Read C = x.shape[-1]. Assert C % num_groups == 0.
  2. gamma = self.param("gamma", ones, (C,)), beta = self.param("beta", zeros, (C,)).
  3. Reshape x to (*leading, G, C/G).
  4. Mean and variance over the last axis (C/G).
  5. Normalize, reshape back.
  6. Apply gamma * x_hat + beta.

The tests pass a 1-D x (no leading dims), so the reshape becomes x.reshape(G, C // G).

Inputs:

  • seed: float (cast to int).
  • x: 1-D JAX array.
  • num_groups: float (cast to int).
  • eps: float.

Output: 1-D array (flattened).

Hints

flax groupnorm self-param
