Squeeze-and-Excitation Block
Why this matters
Convolutions treat every channel equally. But across an image, different channels carry different signals — one might respond to “edges in this region,” another to “global lighting.” A network that could adaptively reweight channels based on global image content would have more capacity per parameter.
Squeeze-and-Excitation (SE) blocks (Hu et al., 2018) do exactly this. They became the secret sauce of:
- SENet (ImageNet 2017 winner).
- EfficientNet family (every block has an SE inside).
- MobileNet v3 (squeeze-and-excite gating in depthwise blocks).
- Many subsequent attention-style modules, which treat SE as the foundational “channel attention” pattern.
The recipe
Three steps. The names are mnemonic:
- Squeeze: collapse spatial dims with global average pool. (H, W, C) → (C,).
- Excite: a tiny MLP that produces per-channel gates. Dense(C/r) → ReLU → Dense(C) → sigmoid. Output shape (C,), values in (0, 1).
- Scale: multiply the original feature map by these gates. (H, W, C) * (C,) → (H, W, C) (broadcast).
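The bottleneck C/r (typical r=16) is the SE block’s secret:
it forces the gates to be computed from a low-dimensional summary
of the channel statistics, which limits overfitting and keeps the
parameter count low (roughly 2C²/r weights for the two Dense layers,
versus C² for a single full-width layer).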
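To make the parameter saving concrete, here is a small arithmetic sketch. The helper names `se_excite_param_count` and `full_width_param_count` are made up for illustration, and the count assumes both Dense layers carry a bias term.

```python
def se_excite_param_count(channels: int, reduction: int) -> int:
    """Weights plus biases of Dense(C/r) followed by Dense(C)."""
    hidden = max(channels // reduction, 1)  # clamp so the bottleneck is never empty
    down = channels * hidden + hidden       # Dense(C/r): kernel + bias
    up = hidden * channels + channels       # Dense(C): kernel + bias
    return down + up


def full_width_param_count(channels: int) -> int:
    """Weights plus biases of a single Dense(C) acting on a (C,) input."""
    return channels * channels + channels


print(se_excite_param_count(64, 16))  # 580
print(full_width_param_count(64))     # 4160
```

With C=64 and r=16 the two bottleneck layers together use about a seventh of the parameters of one full-width layer.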
Squeeze, concretely
squeezed = jnp.mean(x, axis=(0, 1)) # (C,) — unbatched 3-D input
For a batched 4-D input you’d average over (1, 2) instead.
Each channel gets reduced to its global mean — that’s its
“summary statistic.”
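A minimal sketch of the squeeze step on a toy tensor, showing both the unbatched and the batched axis choice. The input values are arbitrary and chosen only so the means are easy to check by hand.

```python
import jax.numpy as jnp

# Unbatched feature map, shape (H, W, C) = (2, 2, 3).
x = jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3)
squeezed = jnp.mean(x, axis=(0, 1))      # (3,): one mean per channel

# Batched feature map, shape (N, H, W, C) = (5, 2, 2, 3).
xb = jnp.stack([x] * 5)
squeezed_b = jnp.mean(xb, axis=(1, 2))   # (5, 3): one mean per sample and channel

print(squeezed)          # [4.5 5.5 6.5]
print(squeezed_b.shape)  # (5, 3)
```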
Excite, concretely
h = nn.Dense(C // r)(squeezed) # bottleneck
h = nn.relu(h)
h = nn.Dense(C)(h) # back to full width
scale = jax.nn.sigmoid(h) # (C,)
The two-layer MLP creates non-linear gating: a channel can be boosted not just because IT had a high mean, but because the INTERACTION of its mean with other channels predicts it should be boosted.
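The same excite computation can be written with explicit weight matrices, which makes it easier to see that every gate depends on all channel means through the shared hidden layer. This is only a sketch: the parameter names `w1`, `b1`, `w2`, `b2`, the 0.1 initialisation scale, and the random stand-in for the pooled means are assumptions for illustration, not part of any library API.

```python
import jax
import jax.numpy as jnp

C, r = 8, 4
hidden = max(C // r, 1)  # bottleneck width, here 2

# Hypothetical randomly initialised parameters; a real block would learn these.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (C, hidden)) * 0.1
b1 = jnp.zeros((hidden,))
w2 = jax.random.normal(k2, (hidden, C)) * 0.1
b2 = jnp.zeros((C,))

squeezed = jax.random.normal(k3, (C,))  # stand-in for the pooled channel means

h = jax.nn.relu(squeezed @ w1 + b1)     # (hidden,): compress
scale = jax.nn.sigmoid(h @ w2 + b2)     # (C,): expand, then squash into (0, 1)

print(scale.shape)  # (8,)
```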
Scale, concretely
out = x * scale # broadcasts (C,) over (H, W, C)
Per-channel multiplication. Each spatial position keeps its
relative pattern but its overall magnitude in each channel is
scaled by scale[c] ∈ (0, 1).
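A short sketch of the broadcast, using hand-picked gate values (an assumption for illustration) so the effect on each channel is visible.

```python
import jax.numpy as jnp

x = jnp.ones((2, 2, 3))                  # (H, W, C)
scale = jnp.array([0.1, 0.5, 0.9])       # (C,) gates, one per channel

out = x * scale                          # implicit broadcast over H and W
out_explicit = x * scale[None, None, :]  # same result with the axes written out

print(out.shape)   # (2, 2, 3)
print(out[0, 0])   # [0.1 0.5 0.9]
```

Every spatial position sees the same three gates, so the layout within a channel is unchanged and only its magnitude shrinks.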
Worked walk-through
(H, W, C) = (4, 4, 4), reduction=2:
- squeezed = jnp.mean(x, axis=(0, 1)) → (4,).
- h = Dense(2)(squeezed) → relu → Dense(4)(h) → sigmoid → (4,).
- out = x * scale[None, None, :] (or just x * scale, which broadcasts on the last axis) → (4, 4, 4).
- Flatten for tests.
Common pitfalls
- Pooling over the wrong axes: must collapse SPATIAL dims, not the channel dim. For (H, W, C) unbatched, that’s (0, 1).
- Forgetting sigmoid: without it, scale can be any real number, so channels can be flipped in sign or amplified beyond 1, which is an entirely different (and unintended) behavior.
- Reducing too far: with C=4, reduction=4, C // r = 1, and a single-unit bottleneck still works. But when r > C the integer division gives C // r = 0, which would crash. Clamp to max(C // r, 1).
- Doing the bottleneck the wrong way: small in the middle, not at the start. The point is to compress THEN expand.
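Two of these pitfalls are easy to check from shapes alone. The sketch below uses a hypothetical helper, `bottleneck_width`, to show the clamp; it is not a library function.

```python
import jax.numpy as jnp

x = jnp.ones((4, 4, 8))  # (H, W, C)

# Wrong axes: averaging over (1, 2) on an unbatched input mixes W and C.
wrong = jnp.mean(x, axis=(1, 2))  # shape (4,) = (H,), not per-channel
right = jnp.mean(x, axis=(0, 1))  # shape (8,) = (C,)


def bottleneck_width(channels: int, reduction: int) -> int:
    """Hidden width for the excite MLP, never less than one unit."""
    return max(channels // reduction, 1)


print(wrong.shape, right.shape)  # (4,) (8,)
print(8 // 16)                   # 0: an unclamped width would be empty
print(bottleneck_width(8, 16))   # 1
```

Note that when H happens to equal C, as in the (4, 4, 4) walk-through, pooling over the wrong axes still returns a (4,) vector, so a shape check alone will not catch the mistake.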
Problem
Implement se_block_forward(seed, x, reduction):
- SEBlock(nn.Module) with reduction field.
- Inside: squeeze with jnp.mean(x, axis=(0, 1)).
- Excite: Dense(max(C//r, 1)) → relu → Dense(C) → sigmoid.
- Scale: multiply x by the per-channel gates.
- Return flattened.
Inputs:
- seed: int.
- x: 3-D (H, W, C).
- reduction: int (channel reduction ratio).
Output: 1-D array of length H * W * C (the same number of elements as the input).