NNX ResNet Basic Block
Why this matters
The ResNet basic block (He et al., 2015) is one of the most influential
architectural primitives of the past decade. The trick: instead of
asking each block to learn H(x) directly, ask it to learn the
residual F(x) = H(x) - x, then output F(x) + x. The skip
connection makes very deep nets trainable: gradients flow back
through the identity path even when F(x) is tiny or saturated.
The same trick powers Transformer blocks (you've already used it: every
x + Sublayer(x) is a residual). ResNet was the original.
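To make the gradient claim concrete, here is a minimal sketch that compares the input gradient of a saturated branch with and without the skip. The function names and the scalar tanh branch are hypothetical stand-ins for F(x), chosen only because tanh saturates easily.

```python
import jax
import jax.numpy as jnp

def branch(x, w):
    # Hypothetical residual branch F(x): a tanh that saturates for large |w * x|.
    return jnp.tanh(w * x)

def plain(x, w):
    return branch(x, w)        # H(x) = F(x), no skip

def residual(x, w):
    return x + branch(x, w)    # H(x) = F(x) + x, with skip

x, w = 5.0, 3.0                # tanh(15) is deep in saturation

g_plain = jax.grad(plain)(x, w)        # dF/dx, vanishes when F saturates
g_residual = jax.grad(residual)(x, w)  # 1 + dF/dx, the identity path keeps it near 1
print(float(g_plain), float(g_residual))
```

The plain gradient collapses towards zero, while the residual gradient stays close to 1 because the identity term contributes a derivative of exactly 1.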
The basic block: two Conv-BN-ReLUs with a skip
x --+--> Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN -> (+) -> ReLU -> out
    |                                               ^
    +-----------------------------------------------+  (skip)
Seven steps:
- conv1 (3x3, stride 1, SAME padding, in/out = features),
- bn1 over channel dim,
- ReLU,
- conv2 (3x3, stride 1, SAME padding, in/out = features),
- bn2 over channel dim,
- add the skip (x itself),
- final ReLU.
Note ReLU is INSIDE the residual on the first conv but the SECOND ReLU is AFTER the skip-add. That's the original "v1" placement. Modern variants (PreAct ResNet, etc.) move things around; we use the canonical v1 here.
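The effect of the ReLU placement can be seen on a tiny example. This is only an illustration: the two arrays are made-up values standing in for the skip input and the bn2 output.

```python
import jax
import jax.numpy as jnp

identity = jnp.array([-1.0, 3.0])   # skip input x (made-up values)
h = jnp.array([-2.0, 1.0])          # stand-in for the bn2 output

# v1 placement: add the skip first, then apply the final ReLU.
v1 = jax.nn.relu(h + identity)        # relu([-3, 4]) = [0, 4]

# Misplaced ReLU: applied before the skip-add.
swapped = jax.nn.relu(h) + identity   # [0, 1] + [-1, 3] = [-1, 4]

print(v1, swapped)
```

With the v1 placement the block output is always non-negative; with the ReLU moved before the add, negative values from the skip path pass straight through.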
nnx makes BatchNorm easy
nnx.BatchNorm(num_features, rngs=rngs) is the built-in. Call it
with use_running_average: bool to switch train/eval. In Linen you'd
have to declare mutable=["batch_stats"] and manage (out, updates)
return tuples; in nnx the running statistics mutate in place when
use_running_average=False, so calling the block is just
block(x, use_running_average=False).
For BatchNorm on (H, W, C) input, the channel dimension is -1
by default and the running stats live as nnx.Variables sized
(C,).
Why no projection on the skip?
The "basic" block assumes the input and output have the same number
of channels: in_features == out_features == features. So the skip
connection is just the input x, no transform. When the channel
count changes (as in ResNet's downsample blocks), the skip needs a
1x1 conv projection to match shapes; that's the "downsample" or
"shortcut" variant. We use the simpler equal-channels case here.
Worked sketch
import jax
from flax import nnx

class BasicBlock(nnx.Module):
    def __init__(self, features, rngs):
        # Both convs keep the channel count and spatial size unchanged.
        self.conv1 = nnx.Conv(
            in_features=features, out_features=features,
            kernel_size=(3, 3), strides=1, padding="SAME", rngs=rngs,
        )
        self.bn1 = nnx.BatchNorm(num_features=features, rngs=rngs)
        self.conv2 = nnx.Conv(
            in_features=features, out_features=features,
            kernel_size=(3, 3), strides=1, padding="SAME", rngs=rngs,
        )
        self.bn2 = nnx.BatchNorm(num_features=features, rngs=rngs)

    def __call__(self, x, use_running_average):
        identity = x  # skip path: the untouched input
        h = self.conv1(x)
        h = self.bn1(h, use_running_average=use_running_average)
        h = jax.nn.relu(h)
        h = self.conv2(h)
        h = self.bn2(h, use_running_average=use_running_average)
        h = h + identity  # skip-add comes before the final ReLU (v1 placement)
        return jax.nn.relu(h)
Compare with Linen, where the same block needs a separate
batch_stats variable collection and a mutable= argument at apply
time. In nnx bn1 and bn2 track their own running stats as
instance state.
Common pitfalls
- Skipping the skip. Without h = h + identity, you have a plain conv stack: no residual, no gradient highway.
- ReLU between BN2 and skip-add. In the v1 block the second ReLU is AFTER the skip-add, not before. Order matters.
- Forgetting the channel check. The basic block requires C == features. If the channel counts differed you'd need a 1x1 projection on the skip.
- Wrong padding. "SAME" keeps spatial dims constant so the skip-add is shape-compatible. "VALID" would shrink H and W by 2 per 3x3 conv and break the add.
- Storing BN modules as plain attributes vs sub-Modules. They ARE sub-modules; assigning self.bn1 = nnx.BatchNorm(...) is correct. Nothing extra needed.
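The padding pitfall is easy to check by comparing output shapes. The sketch below uses jax.lax.conv_general_dilated directly with an all-ones kernel (a placeholder, not trained weights) and an explicit batch dimension of 1.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 8, 8, 4))        # NHWC, batch of 1
kernel = jnp.ones((3, 3, 4, 4))   # HWIO, placeholder 3x3 kernel

def conv(inputs, padding):
    # Stride-1 3x3 convolution with the given padding mode.
    return jax.lax.conv_general_dilated(
        inputs, kernel, window_strides=(1, 1), padding=padding,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )

same = conv(x, "SAME")     # spatial dims preserved
valid = conv(x, "VALID")   # each spatial dim shrinks by kernel_size - 1 = 2

print(same.shape, valid.shape)
```

With "VALID" the branch output after one conv is 6x6 while the skip is still 8x8, so the add at the end of the block would no longer be shape-compatible.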
Problem
Write resnet_basic_forward(seed, x, features):
- BasicBlock(nnx.Module) with conv1, bn1, conv2, bn2. Each conv: kernel_size=(3, 3), strides=1, padding="SAME", in_features=features, out_features=features.
- __call__(x, use_running_average) does conv1 -> bn1 -> ReLU -> conv2 -> bn2 -> +skip -> ReLU.
- Cast features to int. Build with nnx.Rngs(int(seed)). Apply with use_running_average=False (training mode).
- Return out.reshape(-1).
Inputs:
- seed: int (passed as float).
- x: 3-D (H, W, C) channels-last; C == features.
- features: int (passed as float).
Output: 1-D flattened.