NNX Implement GroupNorm
Why this matters
GroupNorm sits between LayerNorm and InstanceNorm: split channels into
G groups and normalize within each group.
- G = 1: equals LayerNorm (statistics over all channels per sample).
- G = C (one channel per group): equals InstanceNorm (statistics per channel per sample, computed over that channel's spatial positions).
- Anything in between: GroupNorm proper.
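The G = 1 case can be checked numerically. The sketch below assumes the simplified form used on this page (statistics over the last axis only, no affine); `group_norm` and `layer_norm` are hypothetical helper names introduced for illustration, and `eps=1e-5` is an assumed default.

```python
import jax.numpy as jnp

def group_norm(x, num_groups, eps=1e-5):
    # Hypothetical helper: normalize the last axis in `num_groups` groups, no affine.
    *leading, c = x.shape
    x_g = x.reshape(*leading, num_groups, c // num_groups)
    mu = x_g.mean(axis=-1, keepdims=True)
    var = x_g.var(axis=-1, keepdims=True)  # population variance (ddof=0)
    return ((x_g - mu) / jnp.sqrt(var + eps)).reshape(x.shape)

def layer_norm(x, eps=1e-5):
    # Hypothetical helper: plain LayerNorm over the last axis, no affine.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

x = jnp.arange(16.0).reshape(2, 8) ** 2  # two samples, C = 8

# With a single group, the group statistics are the LayerNorm statistics.
same = bool(jnp.allclose(group_norm(x, 1), layer_norm(x), atol=1e-5))
# With four groups, each pair of channels is normalized on its own.
differs = not bool(jnp.allclose(group_norm(x, 4), layer_norm(x), atol=1e-5))
```

Both `same` and `differs` should come out true: one group reproduces LayerNorm, while more groups change the statistics each channel is compared against.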
GroupNorm became popular when detection and segmentation models built on ResNets had to train with small batches. BatchNorm degrades sharply as batch size shrinks (the per-batch statistics become noisy); GroupNorm is batch-independent like LayerNorm but normalizes channels in groups, which retains useful per-filter structure. Diffusion U-Nets (DDPM, Stable Diffusion) and the Big Transfer (BiT) ResNets use it as their default normalizer.
API
Two trainable params, both per-CHANNEL (not per-group):
- gamma: shape (C,), init ones.
- beta: shape (C,), init zeros.
Plus num_groups and eps as static attributes. The grouping only
affects WHERE statistics are computed, not the affine transform.
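To make the per-channel point concrete, here is a small sketch of the affine step on its own. The values of `x_hat`, `gamma`, and `beta` are made up for illustration; the only thing taken from the text is that both parameters have shape (C,).

```python
import jax.numpy as jnp

# Already-normalized activations for C = 4 channels in G = 2 groups: [0, 1 | 2, 3].
x_hat = jnp.array([-1.0, 1.0, -1.0, 1.0])

gamma = jnp.array([1.0, 2.0, 3.0, 4.0])  # shape (C,), illustrative values
beta = jnp.array([0.0, 0.5, 0.0, 0.5])   # shape (C,), illustrative values

# Channels 0 and 1 share a group (and so share statistics), yet each one
# gets its own scale and shift.
y = gamma * x_hat + beta
```

Here `y` is `[-1.0, 2.5, -3.0, 4.5]`: the group boundaries play no role once the statistics have been applied.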
Math
For input x with shape (..., C) and G groups (C % G == 0):
- Reshape so the last axis splits into (G, C/G): x_g = x.reshape(*leading, G, C // G).
- Compute mean and variance over the last axis (within each group): mu = mean(x_g, axis=-1, keepdims=True), var = mean((x_g - mu) ** 2, axis=-1, keepdims=True).
- Normalize: x_hat_g = (x_g - mu) / sqrt(var + eps).
- Reshape back to (..., C).
- Apply per-channel affine: gamma * x_hat + beta.
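The same steps can be written per channel in one line each. The index set S_g (the channels in group g) and the map g(c) (the group that channel c belongs to) are notation introduced here for convenience, not taken from the text above; channels and groups are indexed from 0.

```latex
\mathcal{S}_g = \{\, gC/G,\ \dots,\ (g+1)C/G - 1 \,\}, \qquad
g(c) = \lfloor cG/C \rfloor

\mu_g = \frac{G}{C} \sum_{c \in \mathcal{S}_g} x_c, \qquad
\sigma_g^2 = \frac{G}{C} \sum_{c \in \mathcal{S}_g} (x_c - \mu_g)^2

y_c = \gamma_c \, \frac{x_c - \mu_{g(c)}}{\sqrt{\sigma_{g(c)}^2 + \varepsilon}} + \beta_c
```

The statistics carry a group index while gamma and beta carry a channel index, which is the per-group versus per-channel split described in the API section.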
Worked example
Input x = [1, 2, 3, 4, 5, 6] (shape (6,)), G=2:
x_g = [[1, 2, 3], [4, 5, 6]] # (2, 3)
means = [2, 5] # per group
vars = [2/3, 2/3] # per group
x_hat (per group): each becomes [-1.225, 0, 1.225] approximately
flatten back: [-1.225, 0, 1.225, -1.225, 0, 1.225]
With γ=1, β=0: y == x_hat.
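The numbers above can be reproduced in a few lines of JAX. This is only a check of the arithmetic; `eps = 1e-5` is an assumed value, small enough that it does not change the rounded results.

```python
import jax.numpy as jnp

eps = 1e-5  # assumed value; the worked example effectively ignores it

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
x_g = x.reshape(2, 3)                                  # G = 2 groups of 3 channels

mu = x_g.mean(axis=-1, keepdims=True)                  # per-group means: 2 and 5
var = ((x_g - mu) ** 2).mean(axis=-1, keepdims=True)   # per-group variances: 2/3 each

x_hat = ((x_g - mu) / jnp.sqrt(var + eps)).reshape(6)  # flatten back to (6,)
```

Rounded to three decimals, `x_hat` matches the flattened result in the example.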
Worked sketch
import jax.numpy as jnp
from flax import nnx

class MyGroupNorm(nnx.Module):
    def __init__(self, num_features, num_groups, eps, rngs):
        # rngs is accepted but unused: both params have constant inits.
        assert num_features % num_groups == 0
        # Affine params are per-channel, shape (C,).
        self.gamma = nnx.Param(jnp.ones((num_features,)))
        self.beta = nnx.Param(jnp.zeros((num_features,)))
        self.num_groups = num_groups
        self.num_features = num_features
        self.eps = eps

    def __call__(self, x):
        leading = x.shape[:-1]  # () for 1-D input
        g = self.num_groups
        c_per_g = self.num_features // g
        # Split the channel axis into (G, C/G).
        x_g = x.reshape(*leading, g, c_per_g)
        # Per-group statistics over the last axis.
        mu = jnp.mean(x_g, axis=-1, keepdims=True)
        var = jnp.mean((x_g - mu) ** 2, axis=-1, keepdims=True)
        x_hat_g = (x_g - mu) / jnp.sqrt(var + self.eps)
        # Merge the groups back into a single channel axis.
        x_hat = x_hat_g.reshape(*leading, self.num_features)
        return self.gamma * x_hat + self.beta
When to choose GroupNorm
- Small batch training: BatchNorm's batch statistics become unreliable below roughly 16 samples per batch; GroupNorm is independent of batch size.
- Fine-tuning under varying batch sizes: BatchNorm's running stats can mismatch between training and eval; GroupNorm has no running stats.
- Style transfer / generation: per-instance normalization (≈ InstanceNorm via G=C) helps decouple image-level statistics.
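Batch independence is easy to check directly: a sample should normalize to the same values whether it is processed alone or alongside others. The sketch below uses a hypothetical `group_norm` helper (last-axis statistics, no affine) and arbitrary random data.

```python
import jax
import jax.numpy as jnp

def group_norm(x, num_groups, eps=1e-5):
    # Hypothetical helper: normalize the last axis in `num_groups` groups, no affine.
    *leading, c = x.shape
    x_g = x.reshape(*leading, num_groups, c // num_groups)
    mu = x_g.mean(axis=-1, keepdims=True)
    var = x_g.var(axis=-1, keepdims=True)
    return ((x_g - mu) / jnp.sqrt(var + eps)).reshape(x.shape)

batch = jax.random.normal(jax.random.PRNGKey(0), (8, 16))  # 8 samples, C = 16

alone = group_norm(batch[:1], num_groups=4)      # first sample, batch size 1
in_batch = group_norm(batch, num_groups=4)[:1]   # same sample, batch size 8

independent = bool(jnp.allclose(alone, in_batch, atol=1e-5))
```

No statistic is ever taken across the batch axis, so `independent` should be true for any batch size.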
Common pitfalls
- C not divisible by G. Assert in __init__. Picking G=32 and C ∈ {64, 128, 256, ...} is the common convention.
- Reshape order. Split into (G, C/G), not (C/G, G). The first C/G channels go in group 0, the next C/G in group 1, etc.
- gamma/beta per-group instead of per-channel. They're always per-channel, applied AFTER the reshape back.
- num_groups arriving as float. Cast to int.
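The reshape-order and float-cast pitfalls can be seen on a tiny array. This is an illustrative sketch with made-up values (C = 6, G = 2), using channel indices as the data so the grouping is visible.

```python
import jax.numpy as jnp

x = jnp.arange(6.0)        # values equal their channel index: 0..5
num_groups = int(2.0)      # cast first; reshape expects integer dimensions
c_per_g = x.shape[-1] // num_groups

# Correct: (G, C/G) keeps contiguous channels together in each row.
right = x.reshape(num_groups, c_per_g)   # rows: [0, 1, 2] and [3, 4, 5]

# Wrong: (C/G, G) yields three rows of two, so statistics over the last
# axis would cover pairs of channels rather than the intended groups.
wrong = x.reshape(c_per_g, num_groups)   # rows: [0, 1], [2, 3], [4, 5]
```

Both reshapes succeed without error, which is why the wrong order tends to show up only as incorrect outputs.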
Problem
Write groupnorm_forward(seed, x, num_groups, eps):
- Define MyGroupNorm(nnx.Module) with gamma, beta as nnx.Params of shape (num_features,), plus num_groups, num_features, eps as plain attributes.
- __call__: reshape into (..., G, C/G), compute per-group mean and variance over axis=-1, normalize, reshape back, apply gamma * x_hat + beta.
- Build with nnx.Rngs(int(seed)), instantiate (num_features=x.shape[-1], num_groups=int(num_groups), eps=float(eps)), return model(x).
Tests pass 1-D x (no leading dims). Make sure the reshape uses
*leading to support that path.
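For a 1-D input the leading shape is simply the empty tuple, and unpacking it adds no dimensions. A minimal sketch of just the reshape round trip, with made-up values:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])   # 1-D input, C = 4
num_groups = int(2.0)                 # arrives as a float in the tests

leading = x.shape[:-1]                # () when there are no leading dims
x_g = x.reshape(*leading, num_groups, x.shape[-1] // num_groups)  # shape (2, 2)
x_back = x_g.reshape(*leading, x.shape[-1])                       # shape (4,)
```

The same two lines work unchanged for batched inputs, where `leading` would hold the batch dimensions.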
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- num_groups: int (passed as float; cast to int).
- eps: float.
Output: same shape as x.