
Squeeze-and-Excitation Block

Why this matters

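A convolution mixes channels with fixed learned weights, and each output only sees a local receptive field, so nothing in a plain conv layer lets the importance of a channel depend on the image as a whole. But across an image, different channels carry different signals: one might respond to “edges in this region,” another to “global lighting.” A network that could adaptively reweight channels based on global image content gets input-dependent channel weighting for very few extra parameters.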

Squeeze-and-Excitation (SE) blocks (Hu et al., 2018) do exactly this. They became the secret sauce of:

  • SENet (ImageNet 2017 winner).
  • EfficientNet (B0–B7: every MBConv block includes an SE module).
  • MobileNet v3 (SE gating inside many of its depthwise inverted-residual blocks).
  • Many subsequent attention-style modules treat SE as the foundational “channel attention” pattern.

The recipe

Three steps. The names are mnemonic:

  1. Squeeze: collapse spatial dims with global average pool. (H, W, C) → (C,).
  2. Excite: a tiny MLP that produces per-channel gates. Dense(C/r) → ReLU → Dense(C) → sigmoid. Output shape (C,), values in (0, 1).
  3. Scale: multiply the original feature map by these gates. (H, W, C) * (C,) → (H, W, C) (broadcast).
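The same three steps can be written as equations. This is a sketch in column-vector form with biases b₁, b₂ included to match nn.Dense (the original paper's formulation omits biases); note that Flax stores a Dense kernel transposed, as (in, out), so W₁ here corresponds to a kernel of shape (C, C/r).

```latex
\begin{aligned}
z_c &= \frac{1}{HW} \sum_{i=1}^{H} \sum_{j=1}^{W} x_{i,j,c}, & z &\in \mathbb{R}^{C} \\
s &= \sigma\big(W_2\,\mathrm{ReLU}(W_1 z + b_1) + b_2\big), & W_1 &\in \mathbb{R}^{(C/r)\times C},\; W_2 \in \mathbb{R}^{C\times(C/r)} \\
\tilde{x}_{i,j,c} &= s_c \, x_{i,j,c}
\end{aligned}
```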

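The bottleneck C/r (typical r=16) is the SE block’s secret: it FORCES the gates to pass through a low-dimensional summary, which keeps params light (roughly 2C²/r weights instead of 2C² for a full-width two-layer MLP) and limits capacity, which can help against overfitting.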
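To make the savings concrete, here is a small parameter count, assuming both layers carry biases as nn.Dense does by default. The helper names se_param_count and full_width_param_count are hypothetical, introduced only for this illustration.

```python
def se_param_count(channels: int, reduction: int) -> int:
    """Parameters of Dense(C // r) -> Dense(C), hidden width clamped to >= 1."""
    hidden = max(channels // reduction, 1)
    first = channels * hidden + hidden      # kernel (C, C//r) + bias (C//r,)
    second = hidden * channels + channels   # kernel (C//r, C) + bias (C,)
    return first + second


def full_width_param_count(channels: int) -> int:
    """Same two-layer MLP with no bottleneck (hidden width = C), for comparison."""
    return 2 * (channels * channels + channels)


print(se_param_count(256, 16))        # 8464
print(full_width_param_count(256))    # 131584
```

For C=256 and r=16 the bottlenecked MLP is about 15x smaller than the full-width one.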

Squeeze, concretely

squeezed = jnp.mean(x, axis=(0, 1))   # (C,) — unbatched 3-D input

For a batched 4-D input you’d average over (1, 2) instead. Each channel gets reduced to its global mean — that’s its “summary statistic.”
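A quick check of both cases, as a sketch with a small hand-built input: the unbatched map pools over axes (0, 1), the batched one over (1, 2), and each batch element gets its own (C,) summary.

```python
import jax.numpy as jnp

# Unbatched feature map (H, W, C) = (2, 3, 4); x[h, w, c] = 12h + 4w + c.
x = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
squeezed = jnp.mean(x, axis=(0, 1))        # collapse H and W -> (C,)

# Batched feature map (N, H, W, C): the spatial axes are now 1 and 2.
xb = jnp.stack([x, 2.0 * x])               # (2, 2, 3, 4)
squeezed_b = jnp.mean(xb, axis=(1, 2))     # (N, C)

print(squeezed.shape, squeezed_b.shape)    # (4,) (2, 4)
print(squeezed)                            # [10. 11. 12. 13.]
```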

Excite, concretely

h = nn.Dense(max(C // r, 1))(squeezed)   # bottleneck, width clamped to >= 1
h = nn.relu(h)
h = nn.Dense(C)(h)                    # back to full width
scale = jax.nn.sigmoid(h)             # (C,)

The two-layer MLP creates non-linear gating that mixes channels: a channel can get a large gate not just because IT had a high mean, but because the INTERACTION of its mean with the other channels’ means predicts it should be kept strong.
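To see the cross-channel effect in isolation, here is a toy excite step with hand-picked weights (C = 2, hidden width 1, biases omitted). The weights and the excite helper are hypothetical and chosen only to make the effect obvious; a trained SE block learns its own. Changing only channel 0’s mean also moves channel 1’s gate.

```python
import jax
import jax.numpy as jnp

# Toy weights: the single hidden unit sums both channel means, then the
# output layer pushes gate 0 up and gate 1 down as that sum grows.
w1 = jnp.array([[1.0], [1.0]])    # Dense(C -> 1) kernel, shape (in, out)
w2 = jnp.array([[1.0, -1.0]])     # Dense(1 -> C) kernel, shape (in, out)

def excite(squeezed):
    h = jax.nn.relu(squeezed @ w1)    # (1,)
    return jax.nn.sigmoid(h @ w2)     # (2,) gates in (0, 1)

g_before = excite(jnp.array([0.0, 0.0]))
g_after = excite(jnp.array([2.0, 0.0]))   # only channel 0's mean changed

print(g_before)   # [0.5 0.5]
print(g_after)    # gate 1 drops below 0.5 although channel 1's mean is unchanged
```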

Scale, concretely

out = x * scale                       # broadcasts (C,) over (H, W, C)

Per-channel multiplication. Each channel keeps its spatial pattern; only its overall magnitude changes, scaled by the same factor scale[c] ∈ (0, 1) at every position.
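A small broadcasting check, assuming an arbitrary strictly positive (H, W, C) map so that out / x is well defined: the implicit and explicit broadcasts agree, and within each channel the ratio out / x is the same gate at every spatial position.

```python
import jax.numpy as jnp

x = jnp.arange(2 * 2 * 3, dtype=jnp.float32).reshape(2, 2, 3) + 1.0  # (H, W, C), all > 0
scale = jnp.array([0.5, 1.0, 0.1])                                  # (C,) gates

out_implicit = x * scale                  # (C,) broadcasts against the last axis
out_explicit = x * scale[None, None, :]   # same thing, spelled out as (1, 1, C)

# Every position in channel c is multiplied by scale[c], so out / x is
# constant over (H, W) and equals that channel's gate.
ratio = out_implicit / x
print(jnp.array_equal(out_implicit, out_explicit))  # True
```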

Worked walk-through

(H, W, C) = (4, 4, 4), reduction=2:

  1. squeezed = jnp.mean(x, axis=(0, 1)) → (4,).
  2. h = Dense(2)(squeezed) → relu → Dense(4)(h) → sigmoid → (4,).
  3. out = x * scale[None, None, :] (or just x * scale — broadcasts on the last axis) → (4, 4, 4).
  4. Flatten to a 1-D vector of length 4 * 4 * 4 = 64 for the tests.
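The walk-through can be traced step by step with plain jax.numpy. This is a functional re-statement, not the Flax module: the random arrays w1, b1, w2, b2 are hand-made stand-ins for the parameters nn.Dense(2) and nn.Dense(4) would create, so only the shapes (not the values) match what Flax produces.

```python
import jax
import jax.numpy as jnp

H, W, C, reduction = 4, 4, 4, 2
hidden = max(C // reduction, 1)                 # 2

k_x, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (H, W, C))

# Stand-ins for the two Dense layers: kernel (in, out) and bias (out,).
w1 = jax.random.normal(k1, (C, hidden)) * 0.5
b1 = jnp.zeros((hidden,))
w2 = jax.random.normal(k2, (hidden, C)) * 0.5
b2 = jnp.zeros((C,))

squeezed = jnp.mean(x, axis=(0, 1))             # step 1: (4,)
h = jax.nn.relu(squeezed @ w1 + b1)             # step 2a: Dense(2) -> relu, (2,)
scale = jax.nn.sigmoid(h @ w2 + b2)             # step 2b: Dense(4) -> sigmoid, (4,)
out = x * scale                                 # step 3: (4, 4, 4)
flat = out.reshape(-1)                          # step 4: (64,)

for name, arr in [("squeezed", squeezed), ("h", h), ("scale", scale),
                  ("out", out), ("flat", flat)]:
    print(name, arr.shape)
```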

Common pitfalls

  • Pooling over the wrong axes: must collapse SPATIAL dims, not the channel dim. For (H, W, C) unbatched, that’s (0, 1).
  • Forgetting sigmoid: without it, scale can be any real number, so channels can be sign-flipped or amplified by more than 1, which is an entirely different (and unintended) behavior.
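  • Reducing too far: with C=4 and reduction=4, C // r = 1, and a single-unit bottleneck still works. With reduction=8, C // r = 0, which gives a zero-width Dense layer that is degenerate and may fail at initialization. Clamp to max(C // r, 1).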
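  • Putting the bottleneck in the wrong place: the narrow layer is the hidden one (C → C // r → C), not the first or the last. The point is to compress THEN expand back to C, so there is exactly one gate per channel.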
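Two of these pitfalls are easy to check directly. The helper hidden_width below is a hypothetical name for the clamp rule; the second half shows that pooling over the channel axis instead of the spatial axes produces the wrong shape.

```python
import jax.numpy as jnp

def hidden_width(channels: int, reduction: int) -> int:
    """Bottleneck width with the clamp from the pitfall list: never below 1."""
    return max(channels // reduction, 1)

print(4 // 4, hidden_width(4, 4))   # 1 1  (single-unit bottleneck is fine)
print(4 // 8, hidden_width(4, 8))   # 0 1  (unclamped width would be zero)

# Pooling on an unbatched (H, W, C) map: spatial axes vs. the channel axis.
x = jnp.ones((5, 6, 4))
print(jnp.mean(x, axis=(0, 1)).shape)  # (4,)   correct: one value per channel
print(jnp.mean(x, axis=-1).shape)      # (5, 6) wrong: the channels were averaged away
```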

Problem

Implement se_block_forward(seed, x, reduction):

  1. SEBlock(nn.Module) with reduction field.
  2. Inside: squeeze with jnp.mean(x, axis=(0, 1)).
  3. Excite: Dense(max(C//r, 1)) → relu → Dense(C) → sigmoid.
  4. Scale: multiply x by the per-channel gates.
  5. Return the output flattened to 1-D.

Inputs:

  • seed: int.
  • x: 3-D (H, W, C).
  • reduction: int (channel reduction ratio).

Output: 1-D, length H * W * C (same as input).

