NNX ResNet Basic Block

Why this matters

The ResNet basic block (He et al., 2015) is one of the most influential architectural primitives of the past decade. The trick: instead of asking each block to learn H(x) directly, ask it to learn the residual F(x) = H(x) - x, then output F(x) + x. The skip connection makes very deep nets trainable: gradients flow back through the identity path even when F(x) is tiny or saturated.
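To make the gradient argument concrete, here is a minimal sketch in JAX. The function f is a hypothetical stand-in for a saturated residual branch (it is not part of any ResNet); the point is only that adding x contributes an identity term to the derivative.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical residual branch F(x): a tanh pushed deep into saturation.
    return jnp.tanh(50.0 * x)


def plain(x):
    # The block output is F(x) alone, with no skip.
    return f(x)


def residual(x):
    # The block output is F(x) + x.
    return f(x) + x


x0 = 2.0
g_plain = jax.grad(plain)(x0)        # vanishes: the branch is saturated
g_residual = jax.grad(residual)(x0)  # close to 1: the identity path survives
print(g_plain, g_residual)
```

The derivative of F(x) + x is F'(x) + 1, so the gradient stays near 1 even when F'(x) is effectively zero.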

The same trick powers Transformer blocks (you’ve already used it: every x + Sublayer(x) is a residual connection). ResNet is where it was popularized.

The basic block: two Conv-BN-ReLUs with a skip

x ──> Conv3x3 ─> BN ─> ReLU ─> Conv3x3 ─> BN ─> + ─> ReLU ─> out
  │                                             │
  └─────────────────────────────────────────────┘  (skip)

Seven steps:

  1. conv1 (3x3, stride 1, SAME padding, in/out = features),
  2. bn1 over channel dim,
  3. ReLU,
  4. conv2 (3x3, stride 1, SAME padding, in/out = features),
  5. bn2 over channel dim,
  6. add the skip (x itself),
  7. final ReLU.

Note that the first ReLU sits INSIDE the residual branch (right after bn1), while the SECOND ReLU comes AFTER the skip-add. That’s the original “v1” placement. Modern variants (PreAct ResNet, etc.) move things around; we use the canonical v1 here.

nnx makes BatchNorm easy

nnx.BatchNorm(num_features, rngs=rngs) is the built-in. Call it with use_running_average: bool to switch between train and eval. In Linen you’d have to declare mutable=["batch_stats"] and manage (out, updates) return tuples; in nnx the running statistics mutate in place when use_running_average=False, so calling the block is just block(x, use_running_average=False).

For BatchNorm on (H, W, C) input, the feature axis defaults to -1 (the channel dimension), so statistics are computed over H and W, and the running mean and variance live as nnx.Variables of shape (C,).

Why no projection on the skip?

The “basic” block assumes the input and output have the same number of channels: in_features == out_features == features. So the skip connection is just the input x, with no transform. When the channel count changes (as in ResNet’s downsample blocks), the skip needs a 1x1 conv projection to match shapes; that’s the “downsample” or “shortcut” variant. We use the simpler equal-channels case here.

Worked sketch

import jax
from flax import nnx


class BasicBlock(nnx.Module):
    def __init__(self, features, rngs):
        # Both convs keep the channel count and the spatial size unchanged.
        self.conv1 = nnx.Conv(
            in_features=features, out_features=features,
            kernel_size=(3, 3), strides=1, padding="SAME", rngs=rngs,
        )
        self.bn1 = nnx.BatchNorm(num_features=features, rngs=rngs)
        self.conv2 = nnx.Conv(
            in_features=features, out_features=features,
            kernel_size=(3, 3), strides=1, padding="SAME", rngs=rngs,
        )
        self.bn2 = nnx.BatchNorm(num_features=features, rngs=rngs)

    def __call__(self, x, use_running_average):
        identity = x  # skip path: the unmodified input
        h = self.conv1(x)
        h = self.bn1(h, use_running_average=use_running_average)
        h = jax.nn.relu(h)  # first ReLU, inside the residual branch
        h = self.conv2(h)
        h = self.bn2(h, use_running_average=use_running_average)
        h = h + identity  # skip-add; shapes match because C == features
        return jax.nn.relu(h)  # second ReLU, after the skip-add
Compare with Linen, where the same block needs a separate batch_stats variable collection and a mutable= argument at apply time. In nnx bn1 and bn2 track their own running stats as instance state.

Common pitfalls

  • Skipping the skip. Without h = h + identity, you have a plain conv stack: no residual, no gradient highway.
  • ReLU between BN2 and skip-add. In the v1 block the second ReLU is AFTER the skip-add, not before. Order matters.
  • Forgetting the channel check. The basic block requires C == features. If channels were different you’d need a 1x1 projection on the skip.
  • Wrong padding. "SAME" keeps spatial dims constant so the skip-add is shape-compatible. "VALID" would shrink H and W by 2 per 3x3 conv, so h would no longer match identity and the add would fail.
  • Storing BN modules as plain attributes vs sub-Modules. They ARE sub-modules; assigning self.bn1 = nnx.BatchNorm(...) is correct. Nothing extra needed.

Problem

Write resnet_basic_forward(seed, x, features):

  1. BasicBlock(nnx.Module) with conv1, bn1, conv2, bn2. Each conv: kernel_size=(3, 3), strides=1, padding="SAME", in_features=features, out_features=features.
  2. __call__(x, use_running_average) does conv1 -> bn1 -> ReLU -> conv2 -> bn2 -> +skip -> ReLU.
  3. Cast features to int. Build with nnx.Rngs(int(seed)). Apply with use_running_average=False (training mode).
  4. Return out.reshape(-1).

Inputs:

  • seed: int (passed as float).
  • x: 3-D (H, W, C) channels-last; C == features.
  • features: int (passed as float).

Output: 1-D flattened.

