Tiny ResNet Classifier
Why this matters
A residual block by itself doesn't classify anything. To get a full image classifier, you wrap a STACK of blocks with two bookends:
- A stem: an initial conv (often 7x7, stride 2 in full-scale ResNets; here just 3x3, stride 1) that lifts the raw image into feature space and sets the channel count.
- A head: pooling that collapses the spatial dims, plus a final Dense that maps features to per-class logits.
The result is the canonical CNN classifier shape:
image (H, W, 3)
→ STEM: Conv → BN → ReLU
→ BACKBONE: BasicBlock × N
→ HEAD: GlobalAvgPool → Dense(num_classes)
→ logits (num_classes,)
Most modern image classifiers, including ResNet, EfficientNet, ConvNeXt and even ViT (with patches replacing the conv stem), fit this mold.
Global average pooling
The cheap-and-effective replacement for Flatten + Dense(huge).
For a feature map of shape (H, W, C), take the mean over the
spatial axes:
pooled = jnp.mean(features, axis=(1, 2)) # batched: axes 1 and 2
# shape: (B, C)
Each channel is summarized by a single number: its mean over
space. Then a single Dense(num_classes) maps (C,) → (num_classes,).
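A quick way to check the pooling shapes is to run it on random data. The sketch below uses only jax.numpy; the random batch of feature maps (B=2, H=W=4, C=3) is an illustrative assumption, not part of the problem.

```python
import jax
import jax.numpy as jnp

# Illustrative batch of feature maps: B=2, H=4, W=4, C=3
features = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4, 3))

# Global average pooling: mean over the spatial axes (1, 2) only
pooled = jnp.mean(features, axis=(1, 2))
print(pooled.shape)  # (2, 3)

# Equivalent by hand: sum over all H*W positions, divide by their count
h, w = features.shape[1], features.shape[2]
manual = features.sum(axis=(1, 2)) / (h * w)
print(jnp.allclose(pooled, manual, atol=1e-6))  # True

# Unbatched (H, W, C) input: the spatial axes are (0, 1) instead
single = jnp.mean(features[0], axis=(0, 1))
print(single.shape)  # (3,)
```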
Why is this so much better than flatten?
- Far fewer params in the head. Flatten produces H*W*C units; multiplied by num_classes, that weight matrix is enormous.
- Spatial invariance: averaging treats every position equally, so the head does not care where in the image a feature fired. Empirically this gives a healthy regularization effect.
- Resolution-agnostic: the head's parameter count doesn't depend on input size, so the same trained model runs on bigger or smaller images.
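To put a number on "far fewer params", here is a plain-Python count of head parameters (weights plus biases). The sizes are illustrative assumptions: a 7x7x512 final feature map and 1000 classes. The helper names are hypothetical.

```python
def dense_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


def flatten_head_params(h: int, w: int, c: int, num_classes: int) -> int:
    # Flatten feeds H*W*C inputs into the Dense layer
    return dense_params(h * w * c, num_classes)


def gap_head_params(c: int, num_classes: int) -> int:
    # Global average pooling leaves C inputs, whatever H and W are
    return dense_params(c, num_classes)


# Illustrative sizes: 7x7x512 final feature map, 1000 classes
print(flatten_head_params(7, 7, 512, 1000))    # 25089000
print(gap_head_params(512, 1000))              # 513000
# Doubling the resolution quadruples the flatten head; the GAP head is unchanged
print(flatten_head_params(14, 14, 512, 1000))  # 100353000
```

The flatten head's size grows with H*W, which is also why a flattened model is tied to one input resolution.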
Worked walk-through
Input (4, 4, 3), num_classes=3, stem_features = 3 (matches
input channels so the residual works without a projection):
- Add batch dim: (1, 4, 4, 3).
- Stem: Conv3x3(3) → BN → ReLU → (1, 4, 4, 3). With stride 1 and SAME padding, H and W are unchanged.
- BasicBlock(features=3) × 2 → still (1, 4, 4, 3). Each block's residual works because input channels = features.
- Global avg pool over spatial axes (1, 2) → (1, 3).
- Dense(num_classes=3) → (1, 3).
- reshape(-1) → (3,).
The output is the per-class logits vector. (No softmax: that happens in the loss or at inference time.)
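The last three steps of the walk-through (pool, Dense, reshape) can be traced by hand with jax.numpy. This sketch assumes a random stand-in for the backbone output and writes the Dense layer as an explicit affine map with hypothetical `kernel` and `bias` arrays, which is what a Dense layer computes on the last axis.

```python
import jax
import jax.numpy as jnp

key_feat, key_w = jax.random.split(jax.random.PRNGKey(0))

# Stand-in for the backbone output in the walk-through: (1, 4, 4, 3)
features = jax.random.normal(key_feat, (1, 4, 4, 3))

# Hypothetical Dense(num_classes=3) parameters: kernel (C, num_classes), bias (num_classes,)
num_classes = 3
kernel = jax.random.normal(key_w, (features.shape[-1], num_classes))
bias = jnp.zeros((num_classes,))

pooled = jnp.mean(features, axis=(1, 2))  # (1, 3): global average pool
batched_logits = pooled @ kernel + bias   # (1, 3): the Dense step
logits = batched_logits.reshape(-1)       # (3,): drop the batch dim
print(pooled.shape, batched_logits.shape, logits.shape)  # (1, 3) (1, 3) (3,)
```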
Common pitfalls
- Pooling over the wrong axes: with a batch dim, pool over (1, 2), NOT (0, 1, 2), which would also collapse the batch. For unbatched input: (0, 1).
- Forgetting mutable=['batch_stats']: in training mode, every BN in the stem AND in the blocks updates its running statistics during apply. Leave out that one flag and the whole forward pass errors.
- Mismatched channel counts in the residual blocks: this problem keeps stem_features constant through both blocks (matching image.shape[-1]) so the identity skip works throughout. Don't insert a downsampling block here.
- Putting the Dense BEFORE pooling: a classic newbie mistake. Flattening first gives the huge H*W*C × num_classes head; applying Dense to the unpooled (B, H, W, C) map gives logits of the wrong shape, (B, H, W, num_classes).
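Two of these pitfalls are easy to reproduce with plain jax.numpy. The sketch assumes a random (2, 4, 4, 3) batch and again stands in for Dense with an explicit matmul over the last axis (`kernel` is a hypothetical name).

```python
import jax
import jax.numpy as jnp

key_feat, key_w = jax.random.split(jax.random.PRNGKey(0))
features = jax.random.normal(key_feat, (2, 4, 4, 3))  # (B, H, W, C)

# Pitfall: wrong pooling axes
right = jnp.mean(features, axis=(1, 2))     # (2, 3): one vector per image
wrong = jnp.mean(features, axis=(0, 1, 2))  # (3,): batch collapsed as well
# The wrong result is the average of the per-image vectors: the images get mixed
print(jnp.allclose(wrong, right.mean(axis=0), atol=1e-6))  # True

# Pitfall: Dense before pooling, with no pooling afterwards
num_classes = 5
kernel = jax.random.normal(key_w, (features.shape[-1], num_classes))  # hypothetical Dense kernel
unpooled_logits = features @ kernel  # contracts only the last (channel) axis
print(unpooled_logits.shape)  # (2, 4, 4, 5), not (2, 5)
```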
Problem
Implement resnet_classifier_forward(seed, image, num_classes):
- stem_features = image.shape[-1] (so block residuals work without projection).
- Module stack: Conv3x3 → BN → ReLU → BasicBlock × 2 → mean over (H, W) → Dense(num_classes).
- Init/apply with batched input; mutable=['batch_stats'].
- Return logits flattened to 1-D.
Inputs:
- seed: int.
- image: 3-D (H, W, C).
- num_classes: int (output dim).
Output: 1-D, (num_classes,).