ResNet Basic Block
Why this matters
Before ResNet (2015), depth was a curse: stacking more layers made networks HARDER to train, not easier. Gradients vanished, optimization plateaued, and a 56-layer plain net underperformed a 20-layer one.
ResNet’s basic block solved this with a single, elegant trick: the
residual (skip) connection. Instead of asking each block to learn
H(x) directly, we ask it to learn H(x) - x and ADD x back at
the end:
out = relu(F(x) + x) # F is the residual function
The identity mapping is now “free” — if the optimal output is x,
the block just has to drive F(x) → 0. Gradients flow through the
+ x skip undisturbed, so depth no longer kills training. After
ResNet, networks went from “20 layers is brave” to “152 layers is
Tuesday.”
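The two claims above (the identity is "free", and gradients pass through the skip) can be checked on a toy example. This is a minimal sketch, assuming a stand-in residual function F(x) = w * x in place of the real conv stack; residual_block and plain_block are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp


def residual_block(x, w):
    # Toy residual function F(x) = w * x (an assumed stand-in for the conv stack).
    return jax.nn.relu(w * x + x)


def plain_block(x, w):
    # Same toy F(x), but without the skip connection.
    return jax.nn.relu(w * x)


x = jnp.array([0.5, 1.0, 2.0])  # non-negative, as if produced by an earlier ReLU

# With F(x) = 0 (w = 0) the residual block reproduces its input.
identity_out = residual_block(x, 0.0)

# Gradient of the summed output with respect to the input.
grad_residual = jax.grad(lambda v: residual_block(v, 0.0).sum())(x)
grad_plain = jax.grad(lambda v: plain_block(v, 0.0).sum())(x)

print(identity_out)   # [0.5 1.  2. ]
print(grad_residual)  # [1. 1. 1.]  the skip carries the gradient through
print(grad_plain)     # [0. 0. 0.]  nothing reaches the input when F is dead
```

Even when the residual branch contributes nothing, the input still receives a full gradient through the + x term, which is the property that lets very deep stacks train.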
The basic block recipe
x → Conv3x3(features, stride=1) → BN → ReLU
→ Conv3x3(features, stride=1) → BN → (+ x) → ReLU → out
Two 3x3 convs, two batch norms, ONE non-linearity between them, then
add x and apply the final ReLU. The identity skip works as-is when
input and output channels match (no projection shortcut needed). Real
ResNets add a 1x1 projection conv when channel counts differ — this
problem keeps things simple by assuming C_in == features.
BatchNorm with mutable state
Flax’s nn.BatchNorm keeps a running mean and variance in their own
variable collection (batch_stats). When use_running_average=False
(training mode), the layer normalizes with the BATCH stats and writes
updated running stats to that collection. Flax collections are
immutable by default, so that write requires mutable=... in apply,
which then returns the updated collection alongside the output:
variables = model.init(rng, x_b) # x_b has batch dim
params = variables['params']
batch_stats = variables['batch_stats']
out, _ = model.apply(
    {'params': params, 'batch_stats': batch_stats},
    x_b,
    mutable=['batch_stats'],  # required for BN updates
)
Without mutable=['batch_stats'], BN raises because it tries to
write to a read-only collection. (See pos 17 if you implemented BN
by hand.)
Worked walk-through
Input x shape (4, 4, 4) — 4x4 spatial, 4 channels — and features=4.
- Add batch dim: x_b = x[None, ...] → (1, 4, 4, 4).
- nn.Conv(features=4, kernel_size=(3, 3), padding='SAME')(x_b) → (1, 4, 4, 4).
- nn.BatchNorm(use_running_average=False) → normalize each channel with the current batch's statistics.
- nn.relu → ReLU.
- Second Conv3x3 + BN → (1, 4, 4, 4).
- Add the residual, i.e. the batched input: + x_b.
- Final ReLU → (1, 4, 4, 4).
- Reshape to 1-D for the test harness: 64 values.
Common pitfalls
- Forgetting the residual: it’s the whole point. relu(F(x)) without + x is just a plain CNN; you’ve discarded the ResNet trick.
- Putting a ReLU after the second BN, before the add: that is not the original block, and it forces the residual branch to be non-negative. Keep the second ReLU AFTER the addition.
- Forgetting mutable=['batch_stats']: BN with use_running_average=False MUST update its stats; without mutable, Flax raises an error because the batch_stats collection is immutable.
- Channel mismatch: this problem assumes C_in == features. Real ResNets project when channels differ (a 1x1 conv on the residual). Don’t add that here: it would change the test outputs.
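The channel-mismatch pitfall has one case that fails silently: when the input has a single channel, broadcasting makes the addition succeed anyway. The sketch below shows that case along with an explicit shape guard; add_residual is a hypothetical helper and not part of the problem statement.

```python
import jax.numpy as jnp


def add_residual(fx, x):
    # Hypothetical helper: refuse to add unless the shapes match exactly,
    # so broadcasting can never hide a channel mismatch.
    if fx.shape != x.shape:
        raise ValueError(f"residual shape {x.shape} != branch shape {fx.shape}")
    return fx + x


fx = jnp.zeros((1, 4, 4, 4))            # branch output with features=4
x_one_channel = jnp.ones((1, 4, 4, 1))  # input with C_in=1

# Plain addition broadcasts the single channel across all four: no error.
silent = fx + x_one_channel
print(silent.shape)  # (1, 4, 4, 4)

# The guard turns the same mistake into an explicit failure.
try:
    add_residual(fx, x_one_channel)
    caught = False
except ValueError:
    caught = True
print(caught)  # True
```

A check like this adds no projection, so outputs for valid inputs (C_in == features) are unchanged.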
Problem
Implement resnet_basic_forward(seed, x, features):
- Build a BasicBlock nn.Module with features as a field.
- Inside @nn.compact: Conv3x3-BN-ReLU-Conv3x3-BN-add-ReLU.
- Init with a batched input (x[None, ...]); apply with mutable=['batch_stats'].
- Return the output flattened to 1-D.
Inputs:
- seed: int.
- x: 3-D (H, W, C) with C == features.
- features: int (output channels).
Output: 1-D flattened.