
NNX Implement GroupNorm

Why this matters

GroupNorm sits between LayerNorm and InstanceNorm: split the channels into G groups and normalize within each group.

  • G = 1: equals LayerNorm (statistics over all channels, and any spatial axes, per sample).
  • G = C (one channel per group): equals InstanceNorm (statistics per channel per sample, over the spatial axes).
  • Anything in between: GroupNorm proper.
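The LayerNorm and InstanceNorm equivalences hold when statistics also pool over the spatial axes, as canonical GroupNorm does on image tensors. A minimal jax.numpy sketch, assuming an (N, H, W, C) layout; group_norm_nhwc, layer_norm_hwc and instance_norm are hypothetical helper names, and all three skip the affine step:

```python
import jax
import jax.numpy as jnp

def group_norm_nhwc(x, num_groups, eps=1e-5):
    # Hypothetical helper: GroupNorm on (N, H, W, C), pooling over H, W
    # and the C // G channels of each group (no affine).
    n, h, w, c = x.shape
    x_g = x.reshape(n, h, w, num_groups, c // num_groups)
    mu = jnp.mean(x_g, axis=(1, 2, 4), keepdims=True)
    var = jnp.var(x_g, axis=(1, 2, 4), keepdims=True)
    return ((x_g - mu) / jnp.sqrt(var + eps)).reshape(n, h, w, c)

def layer_norm_hwc(x, eps=1e-5):
    # LayerNorm over every non-batch axis (H, W, C), no affine.
    mu = jnp.mean(x, axis=(1, 2, 3), keepdims=True)
    var = jnp.var(x, axis=(1, 2, 3), keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

def instance_norm(x, eps=1e-5):
    # InstanceNorm: per sample and per channel, over the spatial axes (H, W).
    mu = jnp.mean(x, axis=(1, 2), keepdims=True)
    var = jnp.var(x, axis=(1, 2), keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4, 8))
print(jnp.allclose(group_norm_nhwc(x, 1), layer_norm_hwc(x), atol=1e-5))  # G = 1 -> True
print(jnp.allclose(group_norm_nhwc(x, 8), instance_norm(x), atol=1e-5))   # G = C -> True
```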

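GroupNorm became popular for detection and segmentation models, whose ResNet backbones often have to train with only a few images per device. BatchNorm degrades sharply as batch size shrinks because its per-batch statistics (and the running averages built from them) become noisy. GroupNorm is batch-independent like LayerNorm, but it computes statistics within groups of channels, so channels in different groups keep separate means and variances. Detection heads such as FCOS, DETR's panoptic mask head, and the U-Nets used in diffusion models all use it.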

API

Two trainable params, both per-CHANNEL (not per-group):

  • gamma shape (C,), init ones.
  • beta shape (C,), init zeros.

Plus num_groups, num_features and eps as plain (non-trainable) attributes. The grouping only affects WHERE statistics are computed, not the affine transform.
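To see that the grouping leaves the affine step untouched, here is a small jax.numpy illustration with made-up values: x_hat stands in for an already-normalized tensor reshaped back to (..., C), and gamma, beta hold one entry per channel, with different values even for channels that share a group.

```python
import jax.numpy as jnp

C = 6   # with G = 2: channels 0-2 form group 0, channels 3-5 form group 1

# Made-up stand-in for the normalized tensor after reshaping back to (..., C).
x_hat = jnp.array([[-1.0, 0.0, 1.0, -1.0, 0.0, 1.0],
                   [1.0, 0.0, -1.0, 1.0, 0.0, -1.0]])   # shape (2, C)

# One scale and one shift per CHANNEL, even within the same group.
gamma = jnp.array([2.0, 1.0, 0.5, 1.0, 1.0, 3.0])       # shape (C,)
beta = jnp.array([0.0, 0.0, 0.0, 10.0, 0.0, 0.0])       # shape (C,)

# (C,) broadcasts against (..., C), so each channel gets its own affine map.
y = gamma * x_hat + beta
print(y[0])   # values -2, 0, 0.5, 9, 0, 3
```

The parameter count is 2 * C whatever G is; changing num_groups changes only how x_hat is computed.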

Math

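For input x with shape (..., C) and G groups (C % G == 0), statistics are computed over the channel axis only, separately at every position of the leading axes. Canonical GroupNorm on image tensors of shape (N, H, W, C) also pools over H and W; without that pooling, G = C leaves a single value per group, so x_hat is 0 and the output collapses to beta.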

  1. Reshape so the last axis splits into (G, C/G): x_g = x.reshape(*leading, G, C // G), where leading = x.shape[:-1].
  2. Compute mean and variance over the last axis (within each group): mu = mean(x_g, axis=-1, keepdims=True), var = mean((x_g - mu) ** 2, axis=-1, keepdims=True).
  3. Normalize: x_hat_g = (x_g - mu) / sqrt(var + eps).
  4. Reshape back to (..., C).
  5. Apply per-channel affine: gamma * x_hat + beta.
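The five steps collapse into one formula per position of the leading axes. Writing m = C/G and numbering channels from 0, the (G, C/G) reshape puts channel c in group g(c) = ⌊c/m⌋; the statistics are per group, the affine parameters per channel:

```latex
\begin{aligned}
m &= C/G, \qquad \mathcal{S}_g = \{\, g m,\; g m + 1,\; \dots,\; (g+1)m - 1 \,\}, \qquad g(c) = \lfloor c/m \rfloor \\
\mu_g &= \frac{1}{m} \sum_{c \in \mathcal{S}_g} x_c, \qquad
\sigma_g^2 = \frac{1}{m} \sum_{c \in \mathcal{S}_g} \left(x_c - \mu_g\right)^2 \\
y_c &= \gamma_c \, \frac{x_c - \mu_{g(c)}}{\sqrt{\sigma_{g(c)}^2 + \epsilon}} + \beta_c
\end{aligned}
```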

Worked example

Input x = [1, 2, 3, 4, 5, 6] (shape (6,)), G=2:

x_g    = [[1, 2, 3], [4, 5, 6]]   # (2, 3)
means  = [2, 5]                   # per group
vars   = [2/3, 2/3]               # per group
x_hat  = [[-1.225, 0, 1.225], [-1.225, 0, 1.225]]   # per group, approximately
y      = [-1.225, 0, 1.225, -1.225, 0, 1.225]       # flattened back; with gamma=1, beta=0, y == x_hat
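The worked example can be reproduced with a few jax.numpy calls. This assumes eps = 1e-5, a value the example does not fix; exact values are ±sqrt(3/2) ≈ ±1.22474 when eps = 0, and eps = 1e-5 moves them by less than 1e-5.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])   # shape (6,), G = 2
eps = 1e-5                                       # assumed; the example leaves eps unspecified

x_g = x.reshape(2, 3)                            # [[1, 2, 3], [4, 5, 6]]
mu = x_g.mean(axis=-1, keepdims=True)            # [[2], [5]]
var = x_g.var(axis=-1, keepdims=True)            # [[2/3], [2/3]] (biased: divides by 3)
x_hat = ((x_g - mu) / jnp.sqrt(var + eps)).reshape(6)
print(x_hat)   # roughly [-1.2247, 0, 1.2247, -1.2247, 0, 1.2247]
```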

Worked sketch

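import jax.numpy as jnp
from flax import nnx

class MyGroupNorm(nnx.Module):
    def __init__(self, num_features, num_groups, eps, rngs):
        # rngs is accepted to match the usual nnx constructor signature; ones/zeros init needs no randomness.
        num_groups = int(num_groups)  # tests may pass a float
        assert num_features % num_groups == 0, "num_features must be divisible by num_groups"
        self.gamma = nnx.Param(jnp.ones((num_features,)))   # per-channel scale
        self.beta = nnx.Param(jnp.zeros((num_features,)))   # per-channel shift
        self.num_groups = num_groups
        self.num_features = num_features
        self.eps = eps

    def __call__(self, x):
        leading = x.shape[:-1]                    # () for 1-D input
        g = self.num_groups
        c_per_g = self.num_features // g
        x_g = x.reshape(*leading, g, c_per_g)     # split channels into (G, C/G)
        mu = jnp.mean(x_g, axis=-1, keepdims=True)
        var = jnp.mean((x_g - mu) ** 2, axis=-1, keepdims=True)
        x_hat_g = (x_g - mu) / jnp.sqrt(var + self.eps)
        x_hat = x_hat_g.reshape(*leading, self.num_features)  # back to (..., C)
        return self.gamma * x_hat + self.beta                  # per-channel affine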

When to choose GroupNorm

  • Small-batch training: BatchNorm's per-batch statistics become unreliable below roughly 16 samples; GroupNorm is independent of batch size.
  • Fine-tuning under varying batch sizes: BatchNorm's running statistics can mismatch the data seen at evaluation; GroupNorm keeps no running statistics.
  • Style transfer / generation: per-instance normalization (≈ InstanceNorm via G=C) helps decouple image-level statistics.
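The batch-independence claim can be checked numerically. A minimal jax.numpy sketch comparing the channel-only GroupNorm of this problem with training-mode batch statistics, both without the affine step; group_stats_norm and batch_stats_norm are hypothetical names, and eval-mode BatchNorm (which uses running averages instead) is not modeled:

```python
import jax
import jax.numpy as jnp

def group_stats_norm(x, num_groups, eps=1e-5):
    # Per-sample GroupNorm (channel-only variant, no affine) for x of shape (N, C).
    n, c = x.shape
    x_g = x.reshape(n, num_groups, c // num_groups)
    mu = x_g.mean(axis=-1, keepdims=True)
    var = x_g.var(axis=-1, keepdims=True)
    return ((x_g - mu) / jnp.sqrt(var + eps)).reshape(n, c)

def batch_stats_norm(x, eps=1e-5):
    # Training-mode BatchNorm without affine: statistics over the batch axis.
    mu = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

x = jax.random.normal(jax.random.PRNGKey(0), (8, 6))
small = x[:2]   # the same first two samples, now in a batch of 2

# GroupNorm: the first sample's output ignores its batch-mates.
print(jnp.allclose(group_stats_norm(x, 2)[0], group_stats_norm(small, 2)[0]))  # True
# Batch statistics: the first sample's output shifts with the batch.
print(jnp.allclose(batch_stats_norm(x)[0], batch_stats_norm(small)[0]))        # False
```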

Common pitfalls

  • C not divisible by G. Assert in __init__. Picking G=32 and C ∈ {64, 128, 256, ...} is the common convention.
  • Reshape order. Split into (G, C/G), not (C/G, G). The first C/G channels go in group 0, next C/G in group 1, etc.
  • gamma / beta per-group instead of per-channel. They’re always per-channel, applied AFTER the reshape back.
  • num_groups arriving as a float. reshape needs integer sizes, so cast with int(num_groups).
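Two of these pitfalls are easy to see on channel indices. A small jax.numpy sketch; check_group_args is a hypothetical validator (not part of the problem's API) that casts a float group count to int and enforces divisibility:

```python
import jax.numpy as jnp

def check_group_args(num_features, num_groups):
    # Hypothetical validator: accept a whole-number float, return an int G.
    if float(num_groups) != int(num_groups) or int(num_groups) <= 0:
        raise ValueError(f"num_groups must be a positive whole number, got {num_groups}")
    g = int(num_groups)
    if num_features % g != 0:
        raise ValueError(f"C={num_features} is not divisible by G={g}")
    return g

C = 6
G = check_group_args(C, 2.0)        # float in, int 2 out
channels = jnp.arange(C)

# Correct: (G, C // G) puts contiguous channels in each group.
print(channels.reshape(G, C // G))  # rows are groups: [0, 1, 2] and [3, 4, 5]
# Wrong: reshaping to (C // G, G) and treating the size-G axis as the group
# axis interleaves channels across groups.
print(channels.reshape(C // G, G).T)  # rows: [0, 2, 4] and [1, 3, 5]
```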

Problem

Write groupnorm_forward(seed, x, num_groups, eps):

  1. Define MyGroupNorm(nnx.Module) with gamma, beta as nnx.Params of shape (num_features,), plus num_groups, num_features, eps as plain attributes.
  2. __call__: reshape into (..., G, C/G), compute per-group mean and variance over axis=-1, normalize, reshape back, apply gamma * x_hat + beta.
  3. Build with nnx.Rngs(int(seed)), instantiate (num_features=x.shape[-1], num_groups=int(num_groups), eps=float(eps)), return model(x).

Tests pass 1-D x (no leading dims). Make sure the reshape uses *leading to support that path.
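The *leading reshape handles the 1-D case because x.shape[:-1] is the empty tuple and *() unpacks to nothing. A small sketch of the round trip used in the worked sketch above (split_and_merge is a hypothetical helper name); code that hardcodes a batch axis, such as x.reshape(x.shape[0], G, -1), breaks on 1-D input instead:

```python
import jax.numpy as jnp

def split_and_merge(x, num_groups):
    # Hypothetical helper: the (G, C/G) split and merge, returning the shapes seen.
    leading = x.shape[:-1]                                  # () for 1-D input
    c = x.shape[-1]
    x_g = x.reshape(*leading, num_groups, c // num_groups)  # *() unpacks to nothing
    back = x_g.reshape(*leading, c)
    assert jnp.array_equal(back, x)                         # the round trip is lossless
    return leading, x_g.shape, back.shape

print(split_and_merge(jnp.arange(6.0), 2))                  # ((), (2, 3), (6,))
print(split_and_merge(jnp.arange(12.0).reshape(2, 6), 2))   # ((2,), (2, 2, 3), (2, 6))
```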

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • num_groups: int (passed as float — cast).
  • eps: float.

Output: same shape as x.

Hints

flax nnx groupnorm reimplementation
