hard primitives

NNX Tiny ResNet Classifier

Why this matters

A ResNet block (pos 47) is the basic building brick. A ResNet classifier is what you get when you wire a stem (the initial feature-extraction layer), a stack of blocks, and a classifier head around them. The pattern stem -> stages -> global_avg_pool -> linear is the canonical ConvNet recipe: ResNet, ResNeXt, RegNet, ConvNeXt, and EfficientNet all follow this layout and differ mainly in the block design (along with stage widths, depths, and where they downsample).

This problem builds the smallest meaningful version: one stem Conv-BN-ReLU, two basic blocks at a fixed channel width, global average pooling, and a final nnx.Linear(features, num_classes).

The pipeline

image (H, W, C_in)
  |
  v
Conv3x3 -> BN -> ReLU                     # stem (lifts to `features` channels)
  -> (H, W, features)
  |
  v
BasicBlock(features) x 2                  # the residual stages
  -> (H, W, features)
  |
  v
global_avg_pool over (H, W)               # spatial mean -> (features,)
  -> (features,)
  |
  v
nnx.Linear(features, num_classes)         # classifier head
  -> (num_classes,)

For our tests, features = 4 is fixed (we won’t pass it in). The stem keeps the spatial dimensions unchanged (3x3 kernel, stride 1, SAME padding), and so do the blocks (stride 1, SAME). Real ResNets halve (H, W) at stage boundaries with a strided block, but the stride-1 basic block from pos 47 never downsamples.

Stem: lift channels from C_in to features

The image arrives with C_in channels (3 for RGB). The stem conv maps them to features so the rest of the network can use a consistent width:

self.stem_conv = nnx.Conv(
    in_features=in_channels,
    out_features=features,
    kernel_size=(3, 3), strides=1, padding="SAME", rngs=rngs,
)
self.stem_bn = nnx.BatchNorm(num_features=features, rngs=rngs)

The stem is not a residual block: there’s no skip connection, because its input (C_in channels) and output (features channels) differ in width, so x + f(x) would not even have matching shapes. Just Conv-BN-ReLU.

Body: stack BasicBlocks (pos 47)

Two BasicBlock(features=features, rngs=rngs) instances, called in sequence. Each preserves (H, W, features). Since this is a small model, we keep them as named attributes (block1, block2) instead of an nnx.List, which is fine for two- or three-block toys.

Head: global average pool then Linear

Global average pooling collapses the spatial axes to a single vector:

pooled = jnp.mean(x, axis=(0, 1))    # (H, W, features) -> (features,)
logits = self.head(pooled)           # (features,) -> (num_classes,)
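A small numeric example makes the pooling concrete; the 2x2x3 feature map is made up for illustration.

```python
import jax.numpy as jnp

# (H=2, W=2, features=3): channel c at pixel (h, w) holds 3 * (2h + w) + c.
x = jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3)
pooled = jnp.mean(x, axis=(0, 1))  # average the 4 pixels of each channel
print(pooled)  # [4.5 5.5 6.5]
```

Channel 0 sees the values 0, 3, 6, 9, whose mean is 4.5; the other channels follow the same way.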

Why GAP and not flatten + nnx.Linear?

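  1. Translation invariance. GAP weights every spatial position equally, so the head sees a global descriptor of each channel rather than “what channel C looks like at position (h, w).” On top of translation-equivariant convolutions, this makes the prediction largely insensitive to where a feature appears.
  2. Resolution independence. GAP works with any (H, W); flatten ties the model to one input size, because the Linear’s input dimension is H * W * features.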
  3. Fewer parameters. nnx.Linear(features, num_classes) vs nnx.Linear(H * W * features, num_classes).

Worked sketch

import jax
import jax.numpy as jnp
from flax import nnx

# BasicBlock is the residual block from pos 47 (Conv3x3-BN-ReLU-Conv3x3-BN-+x-ReLU).

class ResNetClassifier(nnx.Module):
    def __init__(self, in_channels, features, num_classes, rngs):
        # Stem: lift C_in -> features, keep (H, W).
        self.stem_conv = nnx.Conv(in_features=in_channels, out_features=features,
                                  kernel_size=(3, 3), strides=1, padding="SAME", rngs=rngs)
        self.stem_bn = nnx.BatchNorm(num_features=features, rngs=rngs)
        # Body: two shape-preserving residual blocks.
        self.block1 = BasicBlock(features=features, rngs=rngs)
        self.block2 = BasicBlock(features=features, rngs=rngs)
        # Head: (features,) -> (num_classes,).
        self.head = nnx.Linear(features, num_classes, rngs=rngs)

    def __call__(self, image, use_running_average):
        x = self.stem_conv(image)                                      # (H, W, features)
        x = self.stem_bn(x, use_running_average=use_running_average)
        x = jax.nn.relu(x)
        x = self.block1(x, use_running_average=use_running_average)
        x = self.block2(x, use_running_average=use_running_average)
        pooled = jnp.mean(x, axis=(0, 1))                              # GAP: (features,)
        return self.head(pooled)                                       # (num_classes,)

Common pitfalls

  • GAP over the wrong axes. For an (H, W, features) tensor you want axis=(0, 1) (the spatial axes). axis=-1 averages the channels instead, leaving an (H, W) map with one scalar per pixel.
  • Forgetting to thread use_running_average. Every BatchNorm in the blocks needs it. If the pos 47 block takes it as a required argument, calling self.block1(x) without it raises a TypeError; if it has a default, the call silently runs in whichever mode the default selects.
  • Stem too aggressive. Standard ImageNet ResNets downsample 4x in the stem (a 7x7 stride-2 conv followed by a stride-2 max pool). We use stride 1 because our toy images are tiny, and striding would shrink (H, W) below useful sizes.
  • features mismatch between stem and blocks. The stem’s out_features must equal the blocks’ features: each block’s first conv expects that many input channels, and the identity skip y + x only works when both have the same channel count.
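The first pitfall is easiest to spot by the output shape. A two-line check on an illustrative (8, 8, 4) feature map:

```python
import jax.numpy as jnp

x = jnp.ones((8, 8, 4))                # (H, W, features)
good = jnp.mean(x, axis=(0, 1))        # correct GAP: one value per channel
bad = jnp.mean(x, axis=-1)             # wrong: one value per pixel
print(good.shape, bad.shape)           # (4,) (8, 8)
```

If the head then raises a shape error, or the "pooled" vector has H * W entries instead of features, the reduction axes are the first thing to check.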

Problem

Write resnet_classifier_forward(seed, image, num_classes):

  1. BasicBlock from pos 47 (Conv3x3-BN-ReLU-Conv3x3-BN-+x-ReLU).
  2. ResNetClassifier(nnx.Module) with stem_conv, stem_bn, block1, block2, head.
  3. __call__(image, use_running_average): stem (conv, bn, ReLU) -> block1 -> block2 -> mean over axis=(0, 1) -> head.
  4. Use features = 4 (fixed). Apply with use_running_average=False.
  5. Cast num_classes to int. Return out.reshape(-1).

Inputs:

  • seed: int (passed as float).
  • image: 3-D (H, W, C).
  • num_classes: int (passed as float).

Output: 1-D (num_classes,).

Hints

flax nnx resnet classifier global-avg-pool architecture
