Tiny U-Net
Why this matters
Classification answers "what's in this image?". Segmentation
answers "what is each pixel?". The shapes are different: the output
is not a (num_classes,) vector but an entire (H, W, num_classes)
map, so the architecture has to be different too.
U-Net (2015, biomedical imaging) is the canonical encoder-decoder for dense per-pixel prediction. The encoder downsamples to a low-resolution but information-rich representation; the decoder upsamples back to full resolution; skip connections splice high-resolution features from the encoder into the decoder at matching scales. Without those skips, the decoder would have to reconstruct fine spatial detail from a tiny bottleneck, and it can't.
Diffusion models (DDPM, Stable Diffusion) use U-Nets too: the same "downsample, process, upsample with skips" recipe is what lets them produce a denoised output at the same resolution as their input.
Architecture (this tiny version)
Encoder             Bottom          Decoder
image (H, W, C)
└─ ConvBlock(f) → e1 ───────────────────────── concat → ConvBlock(f) → out
   └─ MaxPool → p1
      └─ ConvBlock(2f) → e2 ─────── concat → ConvBlock(2f)
         └─ MaxPool → p2              │
            └─ ConvBlock(4f) → b      │
               └─ ConvT ──────────────┘ (upsample)
Each ConvBlock is Conv3x3 → BatchNorm → ReLU; keep it
simple. Going down: MaxPool(2x2, stride=2) halves spatial dims.
Going up: ConvTranspose(2x2, stride=2) doubles spatial dims.
At each decoder level, concatenate along the channel axis
with the matching encoder map BEFORE the conv block.
Skip connections, concretely
Concatenation, not addition (that's ResNet's trick). For a 2-D
feature map (H, W, C):
decoder_in = jnp.concatenate([upsampled, encoder_skip], axis=-1)
# shape: (H, W, C_up + C_skip)
The next ConvBlock is responsible for fusing the two streams.
Channels DOUBLE at the concat (here C_up = C_skip), which is why
the design has a follow-up Conv to project them down again.
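To make the channel arithmetic concrete, the short sketch below contrasts concatenation with addition on two dummy feature maps. The shapes are borrowed from the e2 level of the walk-through; the array contents are placeholders.

```python
import jax.numpy as jnp

upsampled = jnp.zeros((2, 2, 4))     # decoder map after upsampling (placeholder values)
encoder_skip = jnp.ones((2, 2, 4))   # matching encoder map, e.g. e2

# U-Net style: stack along the channel axis, so channels add up.
decoder_in = jnp.concatenate([upsampled, encoder_skip], axis=-1)
print(decoder_in.shape)              # (2, 2, 8)

# ResNet style, for contrast: element-wise sum keeps the channel count.
added = upsampled + encoder_skip
print(added.shape)                   # (2, 2, 4)
```

Because axis=-1 always refers to the last axis, the same call works unchanged on batched (N, H, W, C) tensors.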
ConvTranspose for upsampling
nn.ConvTranspose(features, kernel_size=(2, 2), strides=(2, 2), padding='VALID') undoes the shape change of a strided conv: spatial
dims are multiplied by the stride. Output channels are set by features.
Other options exist (bilinear upsample + conv; pixel shuffle),
but transposed conv is the original U-Net choice, and its
upsampling kernel is learned rather than a fixed interpolation.
Worked walk-through
Input (4, 4, 1), base_features=2 (so f=2):
image  (1, 4, 4, 1)
e1  →  (1, 4, 4, 2)   ConvBlock(2)
p1  →  (1, 2, 2, 2)   MaxPool 2x2 stride 2
e2  →  (1, 2, 2, 4)   ConvBlock(4)
p2  →  (1, 1, 1, 4)   MaxPool 2x2 stride 2
b   →  (1, 1, 1, 8)   ConvBlock(8)
u2  →  (1, 2, 2, 4)   ConvTranspose(4)
d2  →  (1, 2, 2, 8)   concat(u2, e2)
d2  →  (1, 2, 2, 4)   ConvBlock(4)
u1  →  (1, 4, 4, 2)   ConvTranspose(2)
d1  →  (1, 4, 4, 4)   concat(u1, e1)
d1  →  (1, 4, 4, 2)   ConvBlock(2)
Final output is (1, 4, 4, 2): same spatial size as the input,
f output channels (a "feature head"; in real segmentation you'd
add a final 1x1 conv to num_classes).
Common pitfalls
- Concat axis wrong: concatenate along the channel axis (axis=-1 for NHWC), NOT along a spatial axis. The wrong axis gives a cryptic shape error or, worse, a silently broken model.
- MaxPool with padding='SAME': this rounds odd spatial sizes up; keep 'VALID' so the encoder/decoder shapes match exactly when you upsample by stride 2.
- Forgetting the skip concat: classic. Without skips, you've built a vanilla autoencoder, and fine detail is lost.
- Mismatched encoder/decoder spatial dims: the test inputs use H = W = 4 (both divisible by 4) so two halvings + two doublings end up matching. With sizes not divisible by 4 you'd need pad/crop logic.
Problem
Implement unet_forward(seed, image, base_features):
- Encoder: ConvBlock(f) → MaxPool → ConvBlock(2f) → MaxPool.
- Bottom: ConvBlock(4f).
- Decoder: ConvTranspose(2f) → concat(skip e2) → ConvBlock(2f) → ConvTranspose(f) → concat(skip e1) → ConvBlock(f).
- Init/apply with batched input; mutable=['batch_stats'].
- Return flattened.
Each ConvBlock is Conv3x3('SAME') → BN(use_running_average=False) → ReLU.
Inputs:
- seed: int.
- image: 3-D (H, W, C). H, W divisible by 4.
- base_features: int (encoder's first stage channel count).
Output: 1-D flattened.