
ResNet Basic Block

Why this matters

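Before ResNet (2015), depth was a curse: stacking more layers made networks HARDER to train, not easier. Gradients vanished or optimization plateaued, and in the ResNet paper's CIFAR-10 experiments a 56-layer plain net had higher training AND test error than a 20-layer one, so the problem was optimization, not overfitting.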

ResNet’s basic block solved this with a single, elegant trick: the residual (skip) connection. Instead of asking each block to learn the target mapping H(x) directly, we ask it to learn the residual F(x) = H(x) - x and ADD x back at the end:

out = relu(F(x) + x)        # F is the residual function

The identity mapping is now “free” — if the optimal output is x, the block just has to drive F(x) → 0. Gradients flow through the + x skip undisturbed, so depth no longer kills training. After ResNet, networks went from “20 layers is brave” to “152 layers is Tuesday.”
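The "identity is free" claim can be checked numerically. The sketch below is illustrative only: residual_fn is a hypothetical stand-in for F (a single linear map instead of the conv stack). It shows that zero weights reduce the block to relu(x), and that the gradient of F(x) + x with respect to x always contains an identity term from the skip.

```python
import jax
import jax.numpy as jnp

def residual_fn(w, x):
    # F(x): toy residual function, a hypothetical stand-in for the conv stack.
    return x @ w

def block(w, x):
    # out = relu(F(x) + x)
    return jax.nn.relu(residual_fn(w, x) + x)

x = jnp.array([[1.0, -2.0, 3.0]])
w_zero = jnp.zeros((3, 3))

# With F(x) = 0 the block reduces to relu(x).
out = block(w_zero, x)                      # [[1., 0., 3.]]

def pre_activation_sum(w, x):
    # Sum of F(x) + x, taken before the final ReLU.
    return jnp.sum(residual_fn(w, x) + x)

# The skip contributes a gradient of exactly 1 per input element,
# on top of whatever F contributes (nothing here, since w is zero).
grad_x = jax.grad(pre_activation_sum, argnums=1)(w_zero, x)   # [[1., 1., 1.]]
```

Even when F contributes no gradient at all, the signal still reaches x through the skip path.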

The basic block recipe

x → Conv3x3(features, stride=1) → BN → ReLU
  → Conv3x3(features, stride=1) → BN → (+ x) → ReLU → out

Two 3x3 convs, two batch norms, ONE non-linearity between them, then add x and apply the final ReLU. The identity skip works as-is when input and output channels match (no projection shortcut needed). Real ResNets add a 1x1 projection conv when channel counts differ — this problem keeps things simple by assuming C_in == features.

BatchNorm with mutable state

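Flax’s nn.BatchNorm keeps a running mean and variance in their own variable collection (batch_stats), separate from params. When use_running_average=False (training mode), the layer normalizes with the current BATCH statistics and writes updated running statistics to batch_stats. Flax modules are functional, so nothing is changed in place: you must pass mutable=... to apply, which then returns the updated collection as a second value: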

variables = model.init(rng, x_b)               # x_b has batch dim
params = variables['params']
batch_stats = variables['batch_stats']
out, _ = model.apply(
    {'params': params, 'batch_stats': batch_stats},
    x_b,
    mutable=['batch_stats'],                   # required for BN updates
)

Without mutable=['batch_stats'], BN raises because it tries to write to a read-only collection. (See pos 17 if you implemented BN by hand.)

Worked walk-through

Input x shape (4, 4, 4) — 4x4 spatial, 4 channels — and features=4.

  1. Add batch dim: x_b = x[None, ...] → (1, 4, 4, 4).
  2. nn.Conv(features=4, kernel_size=(3, 3), padding='SAME')(x_b) → (1, 4, 4, 4).
  3. nn.BatchNorm(use_running_average=False) → normalize each channel over the batch and spatial axes; shape unchanged.
  4. nn.relu → ReLU.
  5. Second Conv3x3 + BN → (1, 4, 4, 4).
  6. Add the residual: + x_b (the batched input, so both operands are (1, 4, 4, 4)).
  7. Final ReLU → (1, 4, 4, 4).
  8. Reshape to 1-D for the test harness.

Common pitfalls

  • Forgetting the residual: it’s the whole point. relu(F(x)) without + x is just a plain CNN; you’ve discarded the ResNet trick.
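  • Putting a ReLU right AFTER the second BN, before the add: that is a different variant, not the original block. It forces the residual branch to be non-negative, so the block can only push x upward. (The pre-activation variant is yet another design: it moves BN and ReLU in front of each conv.) Keep the second ReLU AFTER the addition.
  • Forgetting mutable=['batch_stats']: BN with use_running_average=False MUST update its stats; without mutable, apply fails with an error saying the batch_stats collection is immutable.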
  • Channel mismatch: this problem assumes C_in == features. Real ResNets project when channels differ (a 1x1 conv on the residual). Don’t add that here — it’d change the test outputs.
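If you want the channel assumption to fail loudly rather than surface as a broadcasting error (or, when C_in is 1, as a silent broadcast), a small guard can help. check_identity_skip below is a hypothetical helper, not part of the problem's required interface.

```python
import jax.numpy as jnp

def check_identity_skip(x, features):
    """Hypothetical guard: raise if the identity skip cannot be used as-is."""
    c_in = x.shape[-1]  # channels-last layout, as in the (H, W, C) input
    if c_in != features:
        raise ValueError(
            f"identity skip needs C_in == features, got C_in={c_in}, features={features}"
        )

x = jnp.zeros((4, 4, 4))
check_identity_skip(x, features=4)      # matching channels: passes silently

try:
    check_identity_skip(x, features=8)  # mismatch: would need a 1x1 projection
    mismatch_caught = False
except ValueError:
    mismatch_caught = True
```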

Problem

Implement resnet_basic_forward(seed, x, features):

  1. Build a BasicBlock nn.Module with features as a field.
  2. Inside @nn.compact: Conv3x3-BN-ReLU-Conv3x3-BN-add-ReLU.
  3. Init with a batched input (x[None, ...]); apply with mutable=['batch_stats'].
  4. Return the output flattened to 1-D.

Inputs:

  • seed: int.
  • x: 3-D (H, W, C) with C == features.
  • features: int (output channels).

Output: the block output flattened to a 1-D array of length H * W * features.

Hints

flax resnet residual batchnorm
