
NNX Tiny U-Net

Why this matters

U-Net (Ronneberger et al., 2015) is the canonical architecture for dense prediction — segmentation, image-to-image translation, diffusion model backbones. Its idea: build an encoder that progressively downsamples the input (the “contracting path”), pair it with a decoder that upsamples back to original resolution (the “expanding path”), and connect the two with skip connections that bring high-resolution features directly across.

The skips are the headline. Without them, the decoder would have to rebuild fine-grained spatial detail from a tiny bottleneck, which is almost impossible. With them, the decoder can read the encoder’s same-resolution features directly via concatenation and then refine them.

Architecture (2 encoder/decoder levels)

image (H, W, C_in)
   |
   v
ConvBNReLU(C_in -> f1)        --> e1 (H, W, f1)         ───────┐
   v                                                            │ skip1
maxpool 2x2  -> (H/2, W/2, f1)                                  │
   v                                                            │
ConvBNReLU(f1 -> f2=2*f1)     --> e2 (H/2, W/2, f2)     ──┐    │
   v                                                       │    │
maxpool 2x2  -> (H/4, W/4, f2)                             │    │
   v                                                       │    │
ConvBNReLU(f2 -> f3=4*f1)     --> bottom (H/4, W/4, f3)   │    │
   v                                                       │    │
ConvTranspose stride=2  -> u2 (H/2, W/2, f2)               │    │
   concatenate [u2, e2]  -> (H/2, W/2, 2*f2)         <─────┘    │
   ConvBNReLU(2*f2 -> f2) -> d2 (H/2, W/2, f2)                  │
   v                                                            │
ConvTranspose stride=2  -> u1 (H, W, f1)                        │
   concatenate [u1, e1]  -> (H, W, 2*f1)            <──────────┘
   ConvBNReLU(2*f1 -> f1) -> d1 (H, W, f1)

For the tests: input (4, 4, 3), base_features=4. So f1=4, f2=8, f3=16. After 2 downsamples the bottleneck is (1, 1, 16), which is correct for our toy size.
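The shape bookkeeping in the diagram can be checked with a few lines of plain Python. This is only a sketch: trace_unet_shapes is a hypothetical helper (not part of the problem), and it assumes H and W are divisible by 4 so that both poolings and both upsamples are exact.

```python
def trace_unet_shapes(h, w, c_in, base_features):
    """Return (stage name, shape) pairs for the 2-level U-Net in the diagram."""
    # Assumption: two exact 2x downsamples, so H and W must be multiples of 4.
    assert h % 4 == 0 and w % 4 == 0
    f1, f2, f3 = base_features, 2 * base_features, 4 * base_features
    return [
        ("image", (h, w, c_in)),
        ("e1", (h, w, f1)),
        ("pool1", (h // 2, w // 2, f1)),
        ("e2", (h // 2, w // 2, f2)),
        ("pool2", (h // 4, w // 4, f2)),
        ("bottom", (h // 4, w // 4, f3)),
        ("u2", (h // 2, w // 2, f2)),
        ("cat2", (h // 2, w // 2, 2 * f2)),  # u2 and e2 stacked on channels
        ("d2", (h // 2, w // 2, f2)),
        ("u1", (h, w, f1)),
        ("cat1", (h, w, 2 * f1)),            # u1 and e1 stacked on channels
        ("d1", (h, w, f1)),
    ]


# The test configuration: input (4, 4, 3) with base_features=4.
shapes = dict(trace_unet_shapes(4, 4, 3, 4))
for name, shape in shapes.items():
    print(f"{name:>6}: {shape}")
```

For the test configuration this gives a (1, 1, 16) bottleneck and a (4, 4, 4) output, matching the numbers above.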

Reusable Conv-BN-ReLU block

Define a small ConvBNReLU(nnx.Module) so the rest of the network is just composition:

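import jax
from flax import nnx

class ConvBNReLU(nnx.Module):
    """3x3 conv -> batch norm -> ReLU. Spatial size is preserved."""
    def __init__(self, in_features, out_features, rngs):
        # SAME padding with stride 1 keeps (H, W) unchanged.
        self.conv = nnx.Conv(in_features=in_features, out_features=out_features,
                             kernel_size=(3, 3), strides=1, padding="SAME", rngs=rngs)
        self.bn = nnx.BatchNorm(num_features=out_features, rngs=rngs)

    def __call__(self, x, use_running_average):
        x = self.conv(x)
        # use_running_average=False normalizes with the current input's statistics.
        x = self.bn(x, use_running_average=use_running_average)
        return jax.nn.relu(x)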

Max pooling without nnx

We won’t use an nnx pooling layer here; jax.lax.reduce_window does it cleanly:

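import jax.numpy as jnp
from jax import lax

def max_pool_2x2(x):
    # x is a single unbatched image of shape (H, W, C).
    return lax.reduce_window(
        x, -jnp.inf, lax.max,   # init value -inf is the identity for max
        (2, 2, 1),         # window: 2x2 in space, 1 across channels
        (2, 2, 1),         # strides: non-overlapping windows
        "VALID",
    )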

Note: this is one of the few places -jnp.inf is appropriate. Nothing is normalized through this max: unlike a masked softmax, reduce_window never subtracts or exponentiates the result, so there is no inf - inf to produce NaNs during autodiff. It’s a pure max over a window, and -jnp.inf is the standard identity for max.
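To see both claims in action, here is a small check on a toy (4, 4, 1) input. It repeats max_pool_2x2 so the snippet is self-contained; the input values and the names pooled and grad are illustrative choices, not part of the problem.

```python
import jax
import jax.numpy as jnp
from jax import lax


def max_pool_2x2(x):
    # Same definition as in the text: 2x2 window, stride 2, unbatched (H, W, C).
    return lax.reduce_window(x, -jnp.inf, lax.max, (2, 2, 1), (2, 2, 1), "VALID")


# Values 0..15 laid out row by row, so every window has a unique maximum.
x = jnp.arange(16, dtype=jnp.float32).reshape(4, 4, 1)
pooled = max_pool_2x2(x)

# Differentiate the sum of the pooled output with respect to the input.
grad = jax.grad(lambda t: max_pool_2x2(t).sum())(x)

print(pooled[..., 0])   # maximum of each 2x2 window
print(grad[..., 0])     # nonzero only where a window maximum sits
```

The gradient is finite everywhere despite the -jnp.inf init value, and it routes to exactly one position per window.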

Upsampling: nnx.ConvTranspose

The decoder needs to go from (H/4, W/4, ...) back to (H/2, W/2, ...). A 2x2 transposed convolution with stride 2 does exactly that:

self.up2 = nnx.ConvTranspose(
    in_features=f3, out_features=f2,
    kernel_size=(2, 2), strides=2, padding="VALID", rngs=rngs,
)
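The 2x claim can be sanity-checked without building a module. The sketch below uses the lower-level jax.lax.conv_transpose rather than nnx.ConvTranspose, on the assumption that the module follows the same shape arithmetic for kernel 2, stride 2, VALID padding. The explicit batch axis, the random kernels, and the variable names are all illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

f1, f2, f3 = 4, 8, 16
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))

# lax.conv_transpose wants an explicit batch axis, so use N=1 here.
bottom = jnp.ones((1, 1, 1, f3))                       # (N, H/4, W/4, f3)
kernel_up2 = jax.random.normal(key_a, (2, 2, f3, f2))  # HWIO layout

up2 = lax.conv_transpose(
    bottom, kernel_up2, strides=(2, 2), padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

# Second level: (H/2, W/2, f2) -> (H, W, f1).
d2 = jnp.ones((1, 2, 2, f2))
kernel_up1 = jax.random.normal(key_b, (2, 2, f2, f1))
up1 = lax.conv_transpose(
    d2, kernel_up1, strides=(2, 2), padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

print(up2.shape, up1.shape)
```

With kernel size equal to stride, each input pixel writes to its own non-overlapping 2x2 output patch, which is why the spatial size exactly doubles.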

The skip-concat trick

Right after each upsample, concatenate the same-resolution encoder output along the channel axis BEFORE the post-upsample conv:

u2 = self.up2(b)                        # (H/2, W/2, f2)
cat2 = jnp.concatenate([u2, e2], axis=-1)   # (H/2, W/2, 2 * f2)
d2 = self.dec2(cat2, use_running_average=...)

The post-upsample dec2 ConvBNReLU therefore takes 2 * f2 input channels (the upsampled tensor + the skip).
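A quick way to confirm the channel count and layout is to concatenate two dummy tensors. The zeros and ones below are placeholders standing in for the real u2 and e2, chosen only so the two halves are easy to tell apart.

```python
import jax.numpy as jnp

f2 = 8
u2 = jnp.zeros((2, 2, f2))   # placeholder for the upsampled tensor
e2 = jnp.ones((2, 2, f2))    # placeholder for the encoder skip

cat2 = jnp.concatenate([u2, e2], axis=-1)

print(cat2.shape)                  # channels double, space is unchanged
print(cat2[0, 0])                  # first f2 channels from u2, last f2 from e2
```

The order in the list fixes the channel order, so dec2 always sees the upsampled features first and the skip features second.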

Why concatenate, not add?

Concatenation lets the decoder learn how much weight to give the skip vs the upsampled feature, channel-by-channel. Addition forces them to share representational space. Concat is simpler and works better for U-Net; ResNet uses addition because the residual is structurally a correction, not a separate channel of information.
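One way to make the first sentence concrete: addition followed by a conv is the special case of concat followed by a conv in which both halves of the kernel are forced to be identical. The sketch below shows this numerically, using a 1x1 conv (a per-pixel matmul over channels) in place of the 3x3 conv in dec2 to keep the arithmetic visible. That substitution, the random tensors, and all names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

f = 8
key_u, key_e, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
u = jax.random.normal(key_u, (2, 2, f))   # stand-in for the upsampled tensor
e = jax.random.normal(key_e, (2, 2, f))   # stand-in for the encoder skip
w = jax.random.normal(key_w, (f, f))      # 1x1 conv weights: f -> f channels

# Path A: add, then apply the 1x1 conv.
added = (u + e) @ w

# Path B: concat, then apply a 1x1 conv whose two halves are tied to w.
w_tied = jnp.concatenate([w, w], axis=0)             # (2f, f)
concat_tied = jnp.concatenate([u, e], axis=-1) @ w_tied

print(jnp.max(jnp.abs(added - concat_tied)))         # agreement up to rounding
```

With concatenation the two halves of the kernel are free to differ, so the network can weight skip and upsampled channels independently. Addition removes that freedom.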

Common pitfalls

  • Decoder conv with f channels instead of 2*f. It is easy to forget that the concat doubles the channels. The post-upsample conv takes f1 + f1 = 2*f1 (or f2 + f2 = 2*f2) inputs.
  • Concatenating along the wrong axis. Want axis=-1 (channels); axis=0 or 1 would stack/cat in space and explode shapes.
  • Skip from the wrong stage. u2 upsamples FROM the bottom back TO the e2 resolution; concat with e2. Don’t mix levels.
  • Forgetting to call max_pool_2x2. Without pooling, the spatial dims never shrink and the bottleneck never gets to (1, 1, ...).
  • Stride mismatch in transposed conv. Use kernel_size=(2,2), strides=2 to upsample exactly 2x. Other settings give different output shapes.
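Two of these pitfalls show up directly in the shapes. The sketch below uses dummy tensors sized for the test configuration (f1=4, f2=8). The variable names are illustrative, and the exact exception type raised on a shape mismatch is not assumed, so both TypeError and ValueError are caught.

```python
import jax.numpy as jnp

u2 = jnp.zeros((2, 2, 8))   # upsampled from the bottom, at e2 resolution
e2 = jnp.zeros((2, 2, 8))   # level-2 encoder output
e1 = jnp.zeros((4, 4, 4))   # level-1 encoder output (different resolution)

right = jnp.concatenate([u2, e2], axis=-1)   # channels: (2, 2, 16)
wrong = jnp.concatenate([u2, e2], axis=0)    # height:   (4, 2, 8), silently wrong

# Mixing levels cannot even be concatenated: the spatial dims disagree.
mixed_levels_failed = False
try:
    jnp.concatenate([u2, e1], axis=-1)
except (TypeError, ValueError):
    mixed_levels_failed = True

print(right.shape, wrong.shape, mixed_levels_failed)
```

The wrong-axis case is the more dangerous one because it raises nothing at the concat itself. The error only surfaces later, when dec2 receives 8 channels instead of 16.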

Problem

Write unet_forward(seed, image, base_features):

  1. Define a reusable ConvBNReLU (Conv3x3-BN-ReLU module).
  2. Define max_pool_2x2(x) via jax.lax.reduce_window.
  3. UNet(nnx.Module) with enc1, enc2, bottom (all ConvBNReLU), up2 (nnx.ConvTranspose f3 -> f2), dec2 (ConvBNReLU 2*f2 -> f2), up1 (nnx.ConvTranspose f2 -> f1), dec1 (ConvBNReLU 2*f1 -> f1). f1 = base_features, f2 = 2*f1, f3 = 4*f1.
  4. Forward: enc1 -> save -> pool -> enc2 -> save -> pool -> bottom -> up2 -> concat skip2 -> dec2 -> up1 -> concat skip1 -> dec1. Apply with use_running_average=False.
  5. Cast base_features to int. Return out.reshape(-1).

Inputs:

  • seed: int (passed as float).
  • image: 3-D (H, W, C) channels-last.
  • base_features: int (passed as float).

Output: 1-D flattened (H * W * base_features,).

Hints

flax nnx unet encoder-decoder skip-connections architecture
