
Tiny ResNet Classifier

Why this matters

A residual block by itself doesn’t classify anything. To get a full image classifier, you wrap a STACK of blocks with two bookends:

  1. A stem: an initial conv (often 7x7 stride 2 in full-scale ResNets, here just 3x3 stride 1) that lifts the raw image into feature space and sets the channel count.
  2. A head: pooling that collapses spatial dims, plus a final Dense that maps features to per-class logits.

The result is the canonical CNN classifier shape:

image (H, W, 3)
    → STEM:    Conv → BN → ReLU
    → BACKBONE: BasicBlock × N
    → HEAD:    GlobalAvgPool → Dense(num_classes)
    → logits (num_classes,)

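Most modern image classifiers fit this mold: ResNet, EfficientNet, ConvNeXt, and even ViT (where a patch embedding replaces the conv stem, transformer blocks form the backbone, and the head pools with either a class token or a mean over patches).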

Global average pooling

The cheap-and-effective replacement for Flatten + Dense(huge). For a batched feature map of shape (B, H, W, C), take the mean over the spatial axes:

pooled = jnp.mean(features, axis=(1, 2))   # batched: axes 1 and 2
# shape: (B, C)

Each channel is summarized by a single number: its mean over space. Then a single Dense(num_classes) maps (C,) → (num_classes,) for each example.
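A small numeric check of the pooling step, using only `jax.numpy`. The feature values are `arange` numbers, chosen so the means are easy to verify by hand.

```python
import jax.numpy as jnp

# Batched feature map with B=2, H=2, W=2, C=3.
# features[b, h, w, c] = 12*b + 6*h + 3*w + c
features = jnp.arange(24, dtype=jnp.float32).reshape(2, 2, 2, 3)

# Batched: average over H and W (axes 1, 2) -> one C-vector per image.
pooled = jnp.mean(features, axis=(1, 2))
print(pooled.shape)  # (2, 3)
print(pooled)        # rows: [4.5, 5.5, 6.5] and [16.5, 17.5, 18.5]

# Unbatched: the spatial axes are 0 and 1.
unbatched = jnp.mean(features[0], axis=(0, 1))
print(unbatched)     # [4.5, 5.5, 6.5], same as pooled[0]
```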

Why is this so much better than flatten?

  • Far fewer params in the head. Flatten produces H*W*C features, so the Dense needs H*W*C*num_classes weights; after pooling it needs only C*num_classes.
  • Spatial invariance: averaging treats every position equally, so the head has no position-specific weights to overfit. Empirically this acts as a useful regularizer.
  • Resolution-agnostic: the head’s parameter count doesn’t depend on input size, so the same trained weights accept larger or smaller images.

Worked walk-through

Input (4, 4, 3), num_classes=3, stem_features = 3 (matches input channels so the residual works without a projection):

  1. Add batch dim: (1, 4, 4, 3).
  2. Stem: Conv3x3(3) → BN → ReLU → (1, 4, 4, 3).
  3. BasicBlock(features=3) × 2 → still (1, 4, 4, 3). Each block’s residual works because input channels = features.
  4. Global avg pool over spatial axes (1, 2) → (1, 3).
  5. Dense(num_classes=3) → (1, 3).
  6. reshape(-1) → (3,).

The output is the per-class logits vector. (No softmax: that happens in the loss / inference step.)

Common pitfalls

  • Pooling over the wrong axes: with batch dim, pool over (1, 2), NOT (0, 1, 2) (would also collapse the batch). For unbatched: (0, 1).
  • Forgetting mutable=['batch_stats'] in apply: every BN in the stem AND the blocks updates its running statistics in training mode, so without the flag Flax raises an error as soon as the first BN tries to write them.
  • Mismatched channel counts in the residual blocks: this problem keeps stem_features constant through both blocks (matching image.shape[-1]) so the identity skip works throughout. Don’t insert a downsampling block here.
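  • Skipping the pool or flattening instead: Flax’s Dense acts only on the last axis, so applied directly to the (B, H, W, C) map it returns (B, H, W, num_classes), the wrong shape; flattening first fixes the shape but gives an H*W*C × num_classes weight matrix. A classic newbie mistake.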

Problem

Implement resnet_classifier_forward(seed, image, num_classes):

  1. stem_features = image.shape[-1] (so block residuals work without projection).
  2. Module stack: Conv3x3 → BN → ReLU → BasicBlock × 2 → mean over (H, W) → Dense(num_classes).
  3. Init/apply with batched input; mutable=['batch_stats'].
  4. Return logits flattened to 1-D.

Inputs:

  • seed: int.
  • image: 3-D (H, W, C).
  • num_classes: int (output dim).

Output: 1-D, (num_classes,).

Hints

flax resnet classifier global-avg-pool
