
Tiny U-Net

Why this matters

Classification answers “what’s in this image?”. Segmentation answers “what is each pixel?”. The shapes are different: the output is not a (num_classes,) vector but an entire (H, W, num_classes) map, so the architecture has to be different too.

U-Net (2015, biomedical imaging) is the canonical encoder-decoder for dense per-pixel prediction. The encoder downsamples to a low-resolution but information-rich representation; the decoder upsamples back to full resolution; skip connections splice high-resolution features from the encoder into the decoder at matching scales. Without those skips, the decoder would have to reconstruct fine spatial detail from a tiny bottleneck, and it can’t.

Diffusion models (DDPM, Stable Diffusion) use U-Nets too: the same “downsample, process, upsample with skips” recipe is what lets them denoise full-resolution images.

Architecture (this tiny version)

                Encoder                  Bottom            Decoder
    image (H, W, C)
        └─ ConvBlock(f) ─ e1 ────────────────────────── concat ─ ConvBlock(f) ─ out
              └─ MaxPool ─ p1
                   └─ ConvBlock(2f) ─ e2 ──────── concat ─ ConvBlock(2f)
                          └─ MaxPool ─ p2                ↑
                               └─ ConvBlock(4f) ─ b      ↑
                                              └─ ConvT ──┘ (upsample)

Each ConvBlock is Conv3x3 → BatchNorm → ReLU; keep it simple. Going down: MaxPool(2x2, stride=2) halves the spatial dims. Going up: ConvTranspose(2x2, stride=2) doubles them. At each decoder level, concatenate the upsampled map with the matching encoder map along the channel axis BEFORE the conv block.

Skip connections, concretely

Concatenation, not addition (that’s ResNet’s trick). For a 2-D feature map (H, W, C):

decoder_in = jnp.concatenate([upsampled, encoder_skip], axis=-1)
# shape: (H, W, C_up + C_skip)

The next ConvBlock is responsible for fusing the two streams. Channels DOUBLE at the concat (here C_up = C_skip), which is why each concat is followed by a ConvBlock that projects them back down.

ConvTranspose for upsampling

nn.ConvTranspose(features, kernel_size=(2, 2), strides=(2, 2), padding='VALID') reverses the shape change of a strided conv: with a 2x2 kernel and stride 2, each spatial dim is multiplied by the stride. Output channels are set by features. Other options exist (bilinear upsample + conv; pixel shuffle) but transposed conv is the original U-Net choice and the most parameter-flexible.

Worked walk-through

Input (4, 4, 1), base_features=2 (so f=2). Shapes below include the leading batch axis of size 1:

image     (1, 4, 4, 1)
  e1  →   (1, 4, 4, 2)        ConvBlock(2)
  p1  →   (1, 2, 2, 2)        MaxPool 2x2 stride 2
  e2  →   (1, 2, 2, 4)        ConvBlock(4)
  p2  →   (1, 1, 1, 4)        MaxPool 2x2 stride 2
  b   →   (1, 1, 1, 8)        ConvBlock(8)
  u2  →   (1, 2, 2, 4)        ConvTranspose(4)
  d2  →   (1, 2, 2, 8)        concat(u2, e2)
  d2  →   (1, 2, 2, 4)        ConvBlock(4)
  u1  →   (1, 4, 4, 2)        ConvTranspose(2)
  d1  →   (1, 4, 4, 4)        concat(u1, e1)
  d1  →   (1, 4, 4, 2)        ConvBlock(2)
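The same bookkeeping can be done mechanically for other sizes. The helper below is a hypothetical shape tracer (unet_shapes is not part of the exercise); it only does integer arithmetic and runs no network.

```python
def unet_shapes(h, w, c, f, batch=1):
    """Trace NHWC shapes through the tiny U-Net, step by step."""
    if h % 4 or w % 4:
        raise ValueError("H and W must be divisible by 4 for two 2x2 poolings")
    return [
        ("image",     (batch, h, w, c)),
        ("e1",        (batch, h, w, f)),                 # ConvBlock(f)
        ("p1",        (batch, h // 2, w // 2, f)),       # MaxPool
        ("e2",        (batch, h // 2, w // 2, 2 * f)),   # ConvBlock(2f)
        ("p2",        (batch, h // 4, w // 4, 2 * f)),   # MaxPool
        ("b",         (batch, h // 4, w // 4, 4 * f)),   # ConvBlock(4f)
        ("u2",        (batch, h // 2, w // 2, 2 * f)),   # ConvTranspose(2f)
        ("d2_concat", (batch, h // 2, w // 2, 4 * f)),   # concat(u2, e2)
        ("d2",        (batch, h // 2, w // 2, 2 * f)),   # ConvBlock(2f)
        ("u1",        (batch, h, w, f)),                 # ConvTranspose(f)
        ("d1_concat", (batch, h, w, 2 * f)),             # concat(u1, e1)
        ("d1",        (batch, h, w, f)),                 # ConvBlock(f)
    ]


for name, shape in unet_shapes(4, 4, 1, 2):
    print(f"{name:10s} {shape}")
```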

Final output is (1, 4, 4, 2): same spatial size as the input, with f output channels (a “feature head”; in real segmentation you’d add a final 1x1 conv to num_classes).

Common pitfalls

  • Concat axis wrong: along channel axis (axis=-1 for NHWC), NOT along spatial. Wrong axis → cryptic shape error or, worse, silently broken.
  • MaxPool with padding='SAME': rounds odd sizes up (5 becomes 3), while 'VALID' rounds down (5 becomes 2); on even sizes the two agree. Keep 'VALID' and even sizes so the encoder/decoder shapes match exactly when you upsample by stride 2.
  • Forgetting the skip concat: classic. Without skips, you’ve built a vanilla autoencoder. Detail is lost.
  • Mismatched encoder/decoder spatial dims: the test inputs use H = W = 4 (both divisible by 4) so two halvings + two doublings end up matching. With sizes not divisible by 4 you’d need pad/crop logic.

Problem

Implement unet_forward(seed, image, base_features):

  1. Encoder: ConvBlock(f) → MaxPool → ConvBlock(2f) → MaxPool.
  2. Bottom: ConvBlock(4f).
  3. Decoder: ConvTranspose(2f) → concat(skip e2) → ConvBlock(2f) → ConvTranspose(f) → concat(skip e1) → ConvBlock(f).
  4. Init/apply with batched input; mutable=['batch_stats'].
  5. Return flattened.

Each ConvBlock is Conv3x3('SAME') → BN(use_running_average=False) → ReLU.

Inputs:

  • seed: int.
  • image: 3-D (H, W, C). H, W divisible by 4.
  • base_features: int (encoder’s first stage channel count).

Output: 1-D flattened.

Hints

flax unet encoder-decoder skip-connection
