
ResNet Bottleneck Block

Why this matters

The basic block (pos 49) doesn’t scale. At ResNet-50 widths (features=2048), two consecutive 3x3 convs at full width are crushingly expensive: each conv costs 9 * features^2 parameters, so each block burns ~75M params. ResNet-152 with all basic blocks would be impractical.
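A quick check of the arithmetic above, as a plain-Python sketch that counts conv weights only (biases and BN parameters are ignored). The helper name `conv_weights` is hypothetical. The bottleneck comparison assumes mid = 512, the squeezed width ResNet-50 uses in its last stage.

```python
def conv_weights(k: int, c_in: int, c_out: int) -> int:
    """Hypothetical helper: weights in a k x k conv mapping c_in -> c_out channels (no bias)."""
    return k * k * c_in * c_out


width = 2048
per_conv = conv_weights(3, width, width)          # 37,748,736 per full-width 3x3
basic_block = 2 * per_conv                        # 75,497,472 (~75M) for two of them

# Bottleneck at the same output width, squeezing to mid = width // 4 = 512.
mid = width // 4
bottleneck = (conv_weights(1, width, mid)         # squeeze 2048 -> 512
              + conv_weights(3, mid, mid)         # 3x3 at width 512
              + conv_weights(1, mid, width))      # expand 512 -> 2048; total 4,456,448

print(f"basic block: {basic_block:,}  bottleneck: {bottleneck:,}  "
      f"ratio: {basic_block / bottleneck:.1f}x")
```

The bottleneck needs roughly 17x fewer weights here, which is the "order of magnitude" referred to in the next paragraph.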

The bottleneck block (introduced in the same paper) cuts that cost by an order of magnitude with a 1x1 → 3x3 → 1x1 sandwich. The first 1x1 conv squeezes channels DOWN, the 3x3 operates on the cheaper representation, and a final 1x1 expands them back UP. The result is comparable accuracy in practice at a fraction of the FLOPs.

Every ResNet-50/101/152, every ResNet-based Faster R-CNN backbone, and a long list of later detection and segmentation models are built from stacks of bottleneck blocks.

The sandwich

x → Conv1x1(mid)   → BN → ReLU       (squeeze: out → mid)
  → Conv3x3(mid)   → BN → ReLU       (process at lower width)
  → Conv1x1(out)   → BN → (+ x)      (expand: mid → out, then residual)
  → ReLU → output

Why 1x1 convs? A 1x1 conv is just a linear projection of the channel vector, applied independently at every spatial position: it changes the channel count without touching neighbors. It is cheap (mid * out params) and differentiable: the perfect “channel resizer.”

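For out=256, mid=64: a single full-width 3x3 conv costs 9 * 256^2 ≈ 590K params (a basic block stacks two, ≈1.18M), while the bottleneck triple costs 1*256*64 + 9*64*64 + 1*64*256 = 69,632 ≈ 70K. That is about 8.5x cheaper than even one full-width 3x3 (about 17x cheaper than the whole basic block), with the same input / output shape. These counts ignore biases and BN parameters.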

Identity shortcut

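As with the basic block, this problem assumes the input channels already match out_features, so the shortcut is a plain identity. Real ResNets put a 1x1 projection (followed by BN, and strided when the block downsamples) on the shortcut wherever channels or spatial size don’t match. The original paper calls this option B; option C projects on every shortcut. We skip that here for clarity.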

Worked walk-through

Input (4, 4, 4), mid_features=2, out_features=4:

  1. x_b = x[None, ...] → (1, 4, 4, 4).
  2. Conv1x1(2) → (1, 4, 4, 2): channels squeezed.
  3. BN → ReLU.
  4. Conv3x3(2, padding='SAME') → (1, 4, 4, 2): spatial mixing at low width.
  5. BN → ReLU.
  6. Conv1x1(4) → (1, 4, 4, 4): channels expanded back.
  7. BN, then + x_b (shapes match again), then ReLU.
  8. Reshape to 1-D: (64,).
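The same eight steps can be written against raw `jax.lax` ops instead of Flax modules, with each intermediate shape asserted. This is a sketch under stated assumptions: the hypothetical `batch_norm` helper uses batch statistics with scale 1 and bias 0 (BN's initial values), the hypothetical `conv` helper has no bias, and the random weights do not follow Flax's initializers. Only the shapes and structure match a Flax run, not the numbers.

```python
import jax
import jax.numpy as jnp


def conv(x, w):
    """Hypothetical helper: stride-1 'SAME' conv, x in NHWC, w in HWIO, no bias."""
    return jax.lax.conv_general_dilated(
        x, w, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))


def batch_norm(x, eps=1e-5):
    """Hypothetical helper: training-mode BN at init, per-channel stats over N, H, W."""
    mean = x.mean(axis=(0, 1, 2))
    var = x.var(axis=(0, 1, 2))
    return (x - mean) / jnp.sqrt(var + eps)


def walk_through(x, w1, w2, w3):
    x_b = x[None, ...]                          # 1. add batch dim
    assert x_b.shape == (1, 4, 4, 4)
    y = conv(x_b, w1)                           # 2. squeeze 4 -> 2 channels
    assert y.shape == (1, 4, 4, 2)
    y = jax.nn.relu(batch_norm(y))              # 3. BN -> ReLU
    y = conv(y, w2)                             # 4. 3x3 at low width
    assert y.shape == (1, 4, 4, 2)
    y = jax.nn.relu(batch_norm(y))              # 5. BN -> ReLU
    y = conv(y, w3)                             # 6. expand 2 -> 4 channels
    assert y.shape == (1, 4, 4, 4)
    y = jax.nn.relu(batch_norm(y) + x_b)        # 7. BN, + x_b, then ReLU
    return y.reshape(-1)                        # 8. flatten to (64,)


k1, k2, k3, kx = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (4, 4, 4))
w1 = jax.random.normal(k1, (1, 1, 4, 2))        # HWIO: 1x1, 4 -> 2
w2 = jax.random.normal(k2, (3, 3, 2, 2))        # HWIO: 3x3, 2 -> 2
w3 = jax.random.normal(k3, (1, 1, 2, 4))        # HWIO: 1x1, 2 -> 4
out = walk_through(x, w1, w2, w3)
```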

Three Convs, three BNs, two intermediate ReLUs, ONE final ReLU (after the add).

Common pitfalls

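  • Adding a ReLU between the third BN and the residual add: wrong. The third BN feeds the add directly, and the only ReLU after it comes after the add. A ReLU there would restrict the residual branch to non-negative outputs, which gives a different (worse) architecture.
  • Forgetting to expand back to out_features: the residual sum becomes dimensionally wrong, which normally surfaces as a runtime shape error. The exception is mid_features == 1: a trailing channel dim of 1 broadcasts silently against x, so the bug hides.
  • Using padding='VALID' on the 3x3: it shrinks each spatial dim by 2 and breaks the residual add. Use 'SAME' (Flax's default for nn.Conv) so spatial dims are preserved.
  • Forgetting mutable=['batch_stats']: BN with use_running_average=False updates batch_stats, so apply raises an error unless that collection is mutable. With it, apply returns an (output, updated_variables) tuple, so unpack the output. Same as pos 49.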

Problem

Implement resnet_bottleneck_forward(seed, x, mid_features, out_features):

  1. BottleneckBlock nn.Module with mid_features and out_features fields.
  2. Conv1x1(mid) → BN → ReLU → Conv3x3(mid) → BN → ReLU → Conv1x1(out) → BN → (+ x) → ReLU, with every BN in training mode (use_running_average=False).
  3. Init with batched input; apply with mutable=['batch_stats'].
  4. Return flattened.

Inputs:

  • seed: int.
  • x: 3-D (H, W, C) with C == out_features.
  • mid_features: int (intermediate width, usually out_features / 4).
  • out_features: int (output channels).

Output: 1-D flattened.

Hints

flax resnet bottleneck batchnorm
