NNX Tiny ResNet Classifier
Why this matters
A ResNet block (pos 47) is the basic building brick. A ResNet
classifier wraps a stack of these blocks with a stem (an initial
feature-extraction layer) in front and a classifier head behind.
The pattern stem -> stages -> global_avg_pool -> linear is the
canonical ConvNet recipe: ResNet, ResNeXt, RegNet, ConvNeXt, and
EfficientNet all follow this layout and differ mainly in block design.
This problem builds the smallest meaningful version: one stem
Conv-BN-ReLU, two basic blocks at a fixed channel width, global
average pooling, and a final nnx.Linear(features, num_classes) head.
The pipeline
```text
image (H, W, C_in)
   |
   v
Conv3x3 -> BN -> ReLU                 # stem (lifts to `features` channels)
   -> (H, W, features)
   |
   v
BasicBlock(features) x 2              # the residual stages
   -> (H, W, features)
   |
   v
global_avg_pool over (H, W)           # spatial mean -> (features,)
   -> (features,)
   |
   v
nnx.Linear(features, num_classes)     # classifier head
   -> (num_classes,)
```
For our tests, features = 4 is fixed (we won’t pass it in). The
stem keeps the spatial dimensions the same (3x3, stride 1, SAME
padding), and so do the blocks (stride 1, SAME). Real ResNets
downsample at stage boundaries, where the first block of a stage
uses stride 2, but the stride-1 basic block used here doesn’t.
Stem: lift channels from C_in to features
The image arrives with C_in channels (3 for RGB). The stem conv
maps them to features so the rest of the network can use a
consistent width:
```python
self.stem_conv = nnx.Conv(
    in_features=in_channels,
    out_features=features,
    kernel_size=(3, 3), strides=1, padding="SAME", rngs=rngs,
)
self.stem_bn = nnx.BatchNorm(num_features=features, rngs=rngs)
```
The stem is not residual. An identity skip (x + F(x)) needs matching shapes, and the stem changes the channel count from C_in to features. It is just Conv-BN-ReLU.
Body: stack BasicBlocks (pos 47)
Two BasicBlock(features=features, rngs=rngs) instances are called
in sequence, and each preserves (H, W, features). Since this is a
small model, we keep them as named attributes (block1, block2)
instead of an nnx.List, which is fine for a toy with two or three blocks.
Head: global average pool then Linear
Global average pooling collapses the spatial axes to a single vector:
```python
pooled = jnp.mean(x, axis=(0, 1))  # (H, W, features) -> (features,)
logits = self.head(pooled)         # (features,) -> (num_classes,)
```
Why GAP and not flatten + nnx.Linear?
- Translation invariance. GAP weights every spatial position equally, so the head sees a global descriptor of each channel, not “what channel C looks like at position (h, w).”
- Resolution independence. GAP works with any (H, W); flatten requires a fixed input size to size the Linear.
- Fewer parameters. nnx.Linear(features, num_classes) vs nnx.Linear(H * W * features, num_classes).
Worked sketch
```python
import jax
import jax.numpy as jnp
from flax import nnx

# BasicBlock is the residual block from pos 47, assumed to be in scope.

class ResNetClassifier(nnx.Module):
    def __init__(self, in_channels, features, num_classes, rngs):
        # Stem: Conv3x3 + BN (ReLU applied in __call__), lifts C_in -> features
        self.stem_conv = nnx.Conv(in_features=in_channels, out_features=features,
                                  kernel_size=(3, 3), strides=1, padding="SAME", rngs=rngs)
        self.stem_bn = nnx.BatchNorm(num_features=features, rngs=rngs)
        # Body: two residual blocks at constant width
        self.block1 = BasicBlock(features=features, rngs=rngs)
        self.block2 = BasicBlock(features=features, rngs=rngs)
        # Head: linear classifier on the pooled feature vector
        self.head = nnx.Linear(features, num_classes, rngs=rngs)

    def __call__(self, image, use_running_average):
        x = self.stem_conv(image)                                   # (H, W, features)
        x = self.stem_bn(x, use_running_average=use_running_average)
        x = jax.nn.relu(x)
        x = self.block1(x, use_running_average=use_running_average)
        x = self.block2(x, use_running_average=use_running_average)
        pooled = jnp.mean(x, axis=(0, 1))                           # GAP -> (features,)
        return self.head(pooled)                                    # (num_classes,)
```
Common pitfalls
- GAP over the wrong axes. For (H, W, features) you want axis=(0, 1), the spatial axes. axis=-1 would average over channels instead, leaving a single scalar per pixel (shape (H, W)).
- Forgetting to thread use_running_average. Every BatchNorm inside the blocks needs the train/eval flag. If your pos 47 block takes use_running_average as a required argument of __call__, calling block(x) without it fails immediately with a TypeError.
- Stem too aggressive. Real ImageNet ResNets downsample in the stem (a stride-2 7x7 conv followed by stride-2 max pooling). We use stride 1 because our toy images are tiny, and striding would shrink (H, W) below useful sizes.
- features mismatch between stem and blocks. The stem's out_features must equal the basic block's features. The block's first conv expects that many input channels, and the residual add x + F(x) needs matching widths.
Problem
Write resnet_classifier_forward(seed, image, num_classes):
- BasicBlock from pos 47 (Conv3x3-BN-ReLU-Conv3x3-BN-+x-ReLU).
- ResNetClassifier(nnx.Module) with stem_conv, stem_bn, block1, block2, head.
- __call__(image, use_running_average): stem (conv, BN, ReLU) -> block1 -> block2 -> mean over axis=(0, 1) -> head.
- Use features = 4 (fixed). Apply with use_running_average=False.
- Cast num_classes to int. Return out.reshape(-1).
Inputs:
- seed: int (passed as float).
- image: 3-D (H, W, C).
- num_classes: int (passed as float).
Output: 1-D (num_classes,).