NNX Tiny U-Net
Why this matters
U-Net (Ronneberger et al., 2015) is the canonical architecture for dense prediction — segmentation, image-to-image translation, diffusion model backbones. Its idea: build an encoder that progressively downsamples the input (the “contracting path”), pair it with a decoder that upsamples back to original resolution (the “expanding path”), and connect the two with skip connections that bring high-resolution features directly across.
The skips are the headline. Without them, the decoder would have to
rebuild fine-grained spatial detail from a tiny bottleneck, which is
close to impossible. With them, the decoder can concatenate the
encoder’s same-resolution features onto its own and refine the result.
Architecture (2 encoder/decoder levels)
image (H, W, C_in)
|
v
ConvBNReLU(C_in -> f1) --> e1 (H, W, f1) ───────┐
v │ skip1
maxpool 2x2 -> (H/2, W/2, f1) │
v │
ConvBNReLU(f1 -> f2=2*f1) --> e2 (H/2, W/2, f2) ──┐ │
v │ │
maxpool 2x2 -> (H/4, W/4, f2) │ │
v │ │
ConvBNReLU(f2 -> f3=4*f1) --> bottom (H/4, W/4, f3) │ │
v │ │
ConvTranspose stride=2 --> u2 (H/2, W/2, f2) │ │
concatenate [u2, e2] -> (H/2, W/2, 2*f2) <─────┘ │
ConvBNReLU(2*f2 -> f2) --> d2 (H/2, W/2, f2) │
v │
ConvTranspose stride=2 --> u1 (H, W, f1) │
concatenate [u1, e1] -> (H, W, 2*f1) <──────────┘
ConvBNReLU(2*f1 -> f1) -> d1 (H, W, f1)
For the tests: input (4, 4, 3), base_features=4. So f1=4,
f2=8, f3=16. After 2 downsamples the bottleneck is (1, 1, 16),
which is correct for our toy size.
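The shape bookkeeping above can be checked with a few lines of plain Python. This is only a sketch: `unet_shapes` is a hypothetical helper name, not part of the exercise, and it assumes H and W are divisible by 4.

```python
def unet_shapes(h, w, base_features):
    """Trace (H, W, C) shapes through the 2-level U-Net described above."""
    f1 = base_features
    f2, f3 = 2 * f1, 4 * f1
    return {
        "e1": (h, w, f1),
        "e2": (h // 2, w // 2, f2),
        "bottom": (h // 4, w // 4, f3),
        "cat2": (h // 2, w // 2, 2 * f2),  # u2 and e2 joined on channels
        "d2": (h // 2, w // 2, f2),
        "cat1": (h, w, 2 * f1),            # u1 and e1 joined on channels
        "d1": (h, w, f1),
    }


shapes = unet_shapes(4, 4, 4)
print(shapes["bottom"])  # (1, 1, 16)
print(shapes["d1"])      # (4, 4, 4)
```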
Reusable Conv-BN-ReLU block
Define a small ConvBNReLU(nnx.Module) so the rest of the network
is just composition:
import jax
from flax import nnx

class ConvBNReLU(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        # 3x3 conv with SAME padding keeps the spatial size unchanged.
        self.conv = nnx.Conv(in_features=in_features, out_features=out_features,
                             kernel_size=(3, 3), strides=1, padding="SAME", rngs=rngs)
        self.bn = nnx.BatchNorm(num_features=out_features, rngs=rngs)

    def __call__(self, x, use_running_average):
        x = self.conv(x)
        x = self.bn(x, use_running_average=use_running_average)
        return jax.nn.relu(x)
Max pooling without nnx
We won’t use an nnx pooling layer here; jax.lax.reduce_window does the
job cleanly:
import jax.numpy as jnp
from jax import lax

def max_pool_2x2(x):
    # x is (H, W, C): pool over H and W, leave the channel axis alone.
    return lax.reduce_window(
        x, -jnp.inf, lax.max,
        (2, 2, 1),  # window
        (2, 2, 1),  # strides
        "VALID",
    )
Note: this is one of the few places where -jnp.inf is appropriate.
Nothing normalizes through this max: reduce_window has no softmax-like
operation that would propagate NaNs during autodiff. It’s a pure max
over a window, and -jnp.inf is the standard identity element for max.
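To see the pooling and the note about gradients in action, here is a small sketch on a 4x4 single-channel input. The input values are arbitrary, and the function is repeated so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp
from jax import lax


def max_pool_2x2(x):
    # x is assumed to be (H, W, C); pool over H and W only.
    return lax.reduce_window(x, -jnp.inf, lax.max, (2, 2, 1), (2, 2, 1), "VALID")


x = jnp.arange(16.0).reshape(4, 4, 1)  # values 0..15 laid out row by row
pooled = max_pool_2x2(x)
print(pooled.shape)              # (2, 2, 1)
print(pooled[..., 0].tolist())   # [[5.0, 7.0], [13.0, 15.0]]

# Gradient of the pooled sum w.r.t. the input: finite everywhere,
# with all the mass on the four window maxima.
grad = jax.grad(lambda t: max_pool_2x2(t).sum())(x)
print(bool(jnp.isfinite(grad).all()))
```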
Upsampling: nnx.ConvTranspose
The decoder needs to go from (H/4, W/4, ...) back to
(H/2, W/2, ...). A 2x2 transposed convolution with stride 2 does
exactly that:
self.up2 = nnx.ConvTranspose(
    in_features=f3, out_features=f2,
    kernel_size=(2, 2), strides=2, padding="VALID", rngs=rngs,
)
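The doubling can be checked on the toy bottleneck without building the full module. The sketch below uses `jax.lax.conv_transpose` as a stand-in for the layer, purely to inspect shapes. The all-ones kernel and the added batch dimension are assumptions for illustration, not how nnx initializes or calls the layer.

```python
import jax.numpy as jnp
from jax import lax

f2, f3 = 8, 16
b = jnp.ones((1, 1, 1, f3))        # (N, H/4, W/4, f3) with a batch dim added
kernel = jnp.ones((2, 2, f3, f2))  # (kh, kw, in, out), placeholder weights

u2 = lax.conv_transpose(
    b, kernel, strides=(2, 2), padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print(u2.shape)  # (1, 2, 2, 8)
```

With a 2x2 kernel and stride 2, each input pixel writes to its own 2x2 output block, so the blocks tile the output without overlapping.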
The skip-concat trick
Right after each upsample, concatenate the same-resolution encoder output along the channel axis BEFORE the post-upsample conv:
u2 = self.up2(b) # (H/2, W/2, f2)
cat2 = jnp.concatenate([u2, e2], axis=-1) # (H/2, W/2, 2 * f2)
d2 = self.dec2(cat2, use_running_average=...)
The post-upsample dec2 ConvBNReLU therefore takes 2 * f2 input
channels (the upsampled tensor + the skip).
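A shape-only sketch of the concat, using the toy sizes (f2 = 8 at 2x2 resolution). The zero and one tensors are placeholders standing in for the real `u2` and `e2`.

```python
import jax.numpy as jnp

f2 = 8
u2 = jnp.zeros((2, 2, f2))  # placeholder for the upsampled tensor
e2 = jnp.ones((2, 2, f2))   # placeholder for the encoder skip

cat2 = jnp.concatenate([u2, e2], axis=-1)
print(cat2.shape)   # (2, 2, 16)

# Concatenating on axis 0 grows a spatial dimension instead of channels.
wrong = jnp.concatenate([u2, e2], axis=0)
print(wrong.shape)  # (4, 2, 8)
```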
Why concatenate, not add?
Concatenation lets the decoder learn how much weight to give the skip vs. the upsampled feature, channel by channel, because the conv that follows has separate weights for each input channel. Addition forces the two to share one representational space. Concatenation is the original U-Net design; ResNet uses addition because the residual is structurally a correction, not a separate channel of information.
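One way to see that concatenation is the more general choice: addition is the special case where the projection after the concat is two stacked identity matrices. The sketch below uses a per-pixel matmul as a stand-in for a 1x1 conv. The tensor sizes and values are arbitrary.

```python
import jax.numpy as jnp

f = 3
u = jnp.arange(2 * 2 * f, dtype=jnp.float32).reshape(2, 2, f)  # "upsampled"
e = jnp.ones((2, 2, f))                                        # "skip"

cat = jnp.concatenate([u, e], axis=-1)                  # (2, 2, 2f)
w = jnp.concatenate([jnp.eye(f), jnp.eye(f)], axis=0)   # (2f, f) stacked identities
projected = cat @ w                                      # per-pixel matmul, like a 1x1 conv

print(bool(jnp.allclose(projected, u + e)))  # True
```

Any other choice of `w` gives a mixing that plain addition cannot express.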
Common pitfalls
- Decoder conv with f channels instead of 2*f. Forgetting that the concat doubles the channels. The post-upsample conv takes f1 + f1 = 2*f1 (or f2 + f2 = 2*f2) inputs.
- Concatenating along the wrong axis. You want axis=-1 (channels); axis=0 or 1 would concatenate along a spatial dimension and produce the wrong shape.
- Skip from the wrong stage. u2 upsamples FROM the bottom back TO the e2 resolution; concat with e2. Don’t mix levels.
- Forgetting to call max_pool_2x2. Without pooling, the spatial dims never shrink and the bottleneck never gets to (1, 1, ...).
- Stride mismatch in transposed conv. Use kernel_size=(2,2), strides=2 to upsample exactly 2x. Other settings give different output shapes.
Problem
Write unet_forward(seed, image, base_features):
- Define a reusable ConvBNReLU (Conv3x3-BN-ReLU module).
- Define max_pool_2x2(x) via jax.lax.reduce_window.
- UNet(nnx.Module) with enc1, enc2, bottom (all ConvBNReLU), up2 (nnx.ConvTranspose f3 -> f2), dec2 (ConvBNReLU 2*f2 -> f2), up1 (nnx.ConvTranspose f2 -> f1), dec1 (ConvBNReLU 2*f1 -> f1). f1 = base_features, f2 = 2*f1, f3 = 4*f1.
- Forward: enc1 -> save -> pool -> enc2 -> save -> pool -> bottom -> up2 -> concat skip2 -> dec2 -> up1 -> concat skip1 -> dec1. Apply with use_running_average=False.
- Cast base_features to int. Return out.reshape(-1).
Inputs:
- seed: int (passed as float).
- image: 3-D (H, W, C) channels-last.
- base_features: int (passed as float).
Output: 1-D flattened (H * W * base_features,).
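Because seed and base_features arrive as floats, it helps to normalize them before building anything. A minimal sketch of that step, plus the expected output length, follows. `parse_inputs` is a hypothetical helper name and not part of the required interface.

```python
def parse_inputs(seed, base_features, image_shape):
    """Cast float-typed arguments to int and compute the flattened output size."""
    seed = int(seed)
    f1 = int(base_features)
    h, w, _ = image_shape           # channels-last (H, W, C)
    return seed, f1, h * w * f1     # output length is H * W * base_features


print(parse_inputs(0.0, 4.0, (4, 4, 3)))  # (0, 4, 64)
```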