Implement Depthwise-Separable Convolution

Why this matters

Depthwise-separable convolution is the building block of MobileNet, EfficientNet, Xception, and many other efficient mobile/edge models. The idea is to split a regular convolution into two cheaper steps, which typically gives 5–10× fewer parameters and FLOPs, often with only a small drop in accuracy.

A regular (kh, kw) conv from C_in → C_out has kh * kw * C_in * C_out parameters (ignoring bias).

A depthwise-separable conv replaces it with:

  1. Depthwise (kh, kw) conv: each input channel filtered with its OWN kernel (no cross-channel mixing). Params: kh * kw * C_in.
  2. Pointwise (1, 1) conv: mixes channels with no spatial extent. Params: C_in * C_out.

Total: kh * kw * C_in + C_in * C_out. For typical settings (kh = kw = 3, C_in = C_out = 64), that's 576 + 4096 = 4672 parameters vs 3 * 3 * 64 * 64 = 36864, about 8× fewer.
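As a quick check on the arithmetic above, here is a small Python sketch that counts weights for both variants, ignoring biases. The function names are illustrative, not from any library. The `macs` helper assumes stride 1 and SAME padding, where every weight is applied once per output pixel (counting products against zero padding).

```python
def regular_conv_params(kh: int, kw: int, c_in: int, c_out: int) -> int:
    """Weights in a regular (kh, kw) conv from c_in to c_out channels, no bias."""
    return kh * kw * c_in * c_out


def separable_conv_params(kh: int, kw: int, c_in: int, c_out: int) -> int:
    """Weights in a depthwise (kh, kw) conv plus a pointwise (1, 1) conv, no bias."""
    depthwise = kh * kw * c_in  # one (kh, kw) filter per input channel
    pointwise = c_in * c_out    # a 1x1 conv is a c_in x c_out matrix
    return depthwise + pointwise


def macs(num_weights: int, h: int, w: int) -> int:
    """Multiply-accumulates at stride 1 with SAME padding: each weight fires once per output pixel."""
    return num_weights * h * w


regular = regular_conv_params(3, 3, 64, 64)
separable = separable_conv_params(3, 3, 64, 64)
print(regular, separable)             # 36864 4672
print(round(regular / separable, 2))  # 7.89
print(macs(separable, 32, 32))        # 4784128
```

Since MACs scale with the weight count under these assumptions, the FLOP savings match the parameter savings. In general the cost ratio is separable / regular = 1/C_out + 1/(kh * kw), so for 3×3 kernels the savings approach 9× as C_out grows.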

Depthwise via feature_group_count

jax.lax.conv_general_dilated has a feature_group_count argument:

  • feature_group_count=1 (default): regular conv — every output channel depends on every input channel.
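  • feature_group_count=C_in: the input channels are split into C_in groups of one channel each, and each group is filtered independently. With a kernel output dimension of C_in, this yields exactly one output channel per input channel.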

For depthwise, we set feature_group_count=C_in. The kernel layout is (kh, kw, 1, C_in) — one filter per input channel.
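To make the grouping concrete, the sketch below (assuming NHWC inputs and HWIO kernels, as in the worked structure) compares a single `conv_general_dilated` call with `feature_group_count=C` against a plain loop that convolves each channel on its own with its own (kh, kw, 1, 1) slice of the kernel. `DN`, `depthwise`, and `depthwise_by_loop` are hypothetical names used only for this illustration.

```python
import jax
import jax.numpy as jnp

DN = ("NHWC", "HWIO", "NHWC")


def depthwise(x, kernel):
    """x: (N, H, W, C); kernel: (kh, kw, 1, C) in HWIO order."""
    return jax.lax.conv_general_dilated(
        x, kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=DN,
        feature_group_count=x.shape[-1],  # C groups of one input channel each
    )


def depthwise_by_loop(x, kernel):
    """Reference: filter each channel separately with its own (kh, kw, 1, 1) kernel."""
    outs = []
    for c in range(x.shape[-1]):
        xc = x[..., c:c + 1]       # (N, H, W, 1)
        kc = kernel[..., c:c + 1]  # (kh, kw, 1, 1)
        outs.append(jax.lax.conv_general_dilated(
            xc, kc, (1, 1), "SAME", dimension_numbers=DN))
    return jnp.concatenate(outs, axis=-1)


key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, 8, 8, 4))
kernel = jax.random.normal(key_k, (3, 3, 1, 4))

y = depthwise(x, kernel)
print(y.shape)  # (1, 8, 8, 4)
print(bool(jnp.allclose(y, depthwise_by_loop(x, kernel), atol=1e-4)))  # True
```

The loop version is only there for comparison; the grouped call does the same work in one fused operation.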

Pointwise = 1×1 conv

A pointwise conv is just a 1×1 convolution; in Flax that is nn.Conv(features=C_out, kernel_size=(1, 1)). It mixes channels at each spatial location independently, so it is equivalent to a Dense layer applied to every pixel.
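The Dense-per-pixel equivalence can be checked directly. This sketch, assuming NHWC activations and an HWIO kernel of shape (1, 1, C_in, C_out), runs a 1×1 `conv_general_dilated` and compares it with an `einsum` that multiplies every pixel's channel vector by the same (C_in, C_out) matrix.

```python
import jax
import jax.numpy as jnp

key_x, key_k = jax.random.split(jax.random.PRNGKey(1))
h = jax.random.normal(key_x, (1, 5, 5, 4))   # (N, H, W, C_in)
pw = jax.random.normal(key_k, (1, 1, 4, 6))  # (1, 1, C_in, C_out), HWIO

# 1x1 convolution over the whole feature map.
y_conv = jax.lax.conv_general_dilated(
    h, pw, window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))

# The same computation as a per-pixel matrix multiply (a bias-free Dense layer).
y_dense = jnp.einsum("nhwc,cd->nhwd", h, pw[0, 0])

print(y_conv.shape)  # (1, 5, 5, 6)
print(bool(jnp.allclose(y_conv, y_dense, atol=1e-4)))  # True
```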

Worked structure

import jax
import flax.linen as nn


class DepthwiseSeparable(nn.Module):
    features: int          # output channels
    kernel_h: int
    kernel_w: int

    @nn.compact
    def __call__(self, x):       # x: (H, W, C_in), unbatched
        c_in = x.shape[-1]
        depthwise_kernel = self.param(
            "depthwise_kernel",
            nn.initializers.lecun_normal(),
            (self.kernel_h, self.kernel_w, 1, c_in),  # (kh, kw, 1, C_in)
        )
        pointwise_kernel = self.param(
            "pointwise_kernel",
            nn.initializers.lecun_normal(),
            (1, 1, c_in, self.features),              # (1, 1, C_in, C_out)
        )
        bias = self.param("bias", nn.initializers.zeros, (self.features,))

        x_b = x[None, ...]  # add a batch dim: (1, H, W, C_in)
        # Step 1: depthwise, one group per input channel (feature_group_count=c_in).
        h = jax.lax.conv_general_dilated(
            x_b, depthwise_kernel,
            window_strides=(1, 1),
            padding="SAME",
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
            feature_group_count=c_in,
        )
        # Step 2: pointwise (1x1) conv mixes channels at every pixel.
        h = jax.lax.conv_general_dilated(
            h, pointwise_kernel,
            window_strides=(1, 1),
            padding="SAME",
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
        )
        return h[0] + bias  # drop the batch dim: (H, W, C_out)

The module has two kernels (depthwise_kernel, pointwise_kernel) plus one bias applied to the output. Note: real implementations often insert normalization and a non-linearity between the depthwise and pointwise steps (MobileNet-V2 style: depthwise → BN → ReLU → pointwise → BN). We omit those here for simplicity.
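One way to see where the savings come from: without a non-linearity in between, depthwise followed by pointwise is exactly a regular conv whose kernel is forced into the factorized form K[a, b, c, o] = dw[a, b, 0, c] * pw[0, 0, c, o]. The sketch below checks this numerically; `separable` and `equivalent_full_kernel` are illustrative names, and the equivalence stops holding once BN or ReLU is inserted between the two steps.

```python
import jax
import jax.numpy as jnp

DN = ("NHWC", "HWIO", "NHWC")


def separable(x, dw, pw):
    """Depthwise (groups = C_in) then pointwise; no bias, no non-linearity."""
    h = jax.lax.conv_general_dilated(
        x, dw, (1, 1), "SAME", dimension_numbers=DN,
        feature_group_count=x.shape[-1])
    return jax.lax.conv_general_dilated(h, pw, (1, 1), "SAME", dimension_numbers=DN)


def equivalent_full_kernel(dw, pw):
    """Regular HWIO kernel K[a, b, c, o] = dw[a, b, 0, c] * pw[0, 0, c, o]."""
    return jnp.einsum("abc,co->abco", dw[:, :, 0, :], pw[0, 0])


k_x, k_dw, k_pw = jax.random.split(jax.random.PRNGKey(2), 3)
x = jax.random.normal(k_x, (1, 7, 7, 3))
dw = jax.random.normal(k_dw, (3, 3, 1, 3))
pw = jax.random.normal(k_pw, (1, 1, 3, 5))

k_full = equivalent_full_kernel(dw, pw)
y_full = jax.lax.conv_general_dilated(x, k_full, (1, 1), "SAME", dimension_numbers=DN)

print(k_full.shape)  # (3, 3, 3, 5)
print(bool(jnp.allclose(separable(x, dw, pw), y_full, atol=1e-4)))  # True
```

The full kernel has 3 * 3 * 3 * 5 = 135 entries, but they are generated from only 27 + 15 = 42 free weights; that rank constraint is the trade-off behind the parameter savings.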

Common pitfalls

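  • feature_group_count != C_in: feature_group_count=1 gives a regular conv, and any other divisor of C_in gives a grouped conv in which each output channel still sees several input channels. Neither is depthwise, so set it to C_in exactly.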
  • Wrong depthwise kernel layout: (kh, kw, 1, C_in) not (kh, kw, C_in, 1). The I=1 slot is the per-group input channel count (which is 1 for depthwise); the O=C_in slot is the total output channels.
  • Forgetting pointwise: depthwise alone doesn’t mix channels at all. Without pointwise, you’ve lost the ability to combine information across input channels.
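The first and third pitfalls can be seen by perturbing a single input channel and listing which output channels change. This is a sketch with all-ones kernels (so the connectivity, not the weight values, decides the result); `conv` and `changed_output_channels` are hypothetical helpers. Depthwise alone touches only the matching channel, a 2-group conv touches the whole group, and adding the pointwise step spreads the change to every output.

```python
import jax
import jax.numpy as jnp

DN = ("NHWC", "HWIO", "NHWC")


def conv(x, k, groups=1):
    # Stride 1, SAME padding; `groups` is passed as feature_group_count.
    return jax.lax.conv_general_dilated(
        x, k, (1, 1), "SAME", dimension_numbers=DN, feature_group_count=groups)


def changed_output_channels(f, x, channel):
    """Output channels of f that change when only input `channel` is perturbed."""
    delta = jnp.zeros_like(x).at[..., channel].set(1.0)
    diff = jnp.abs(f(x + delta) - f(x)).sum(axis=(0, 1, 2))  # one value per output channel
    return [int(i) for i in jnp.nonzero(diff > 1e-6)[0]]


x = jax.random.normal(jax.random.PRNGKey(0), (1, 6, 6, 4))  # C_in = 4
dw = jnp.ones((3, 3, 1, 4))         # depthwise kernel, used with groups=4
grouped_k = jnp.ones((3, 3, 2, 4))  # groups=2: each group sees 4 // 2 = 2 inputs
pw = jnp.ones((1, 1, 4, 4))         # pointwise kernel


def depthwise_only(t):
    return conv(t, dw, groups=4)


def grouped(t):
    return conv(t, grouped_k, groups=2)


def separable(t):
    return conv(conv(t, dw, groups=4), pw)


print(changed_output_channels(depthwise_only, x, 2))  # [2]
print(changed_output_channels(grouped, x, 2))         # [2, 3]
print(changed_output_channels(separable, x, 2))       # [0, 1, 2, 3]
```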

Problem

Implement MyDepthwiseSeparable(features, kernel_h, kernel_w):

  1. depthwise_kernel: (kernel_h, kernel_w, 1, C_in), lecun_normal().
  2. pointwise_kernel: (1, 1, C_in, features), lecun_normal().
  3. bias: (features,), zeros.
  4. Run depthwise with feature_group_count=c_in, then pointwise.
  5. Return the (H, W, features) output flattened with .reshape(-1).

Use padding="SAME", stride (1, 1) throughout.

Inputs:

  • seed: float (cast to int).
  • x: 3-D (H, W, C_in).
  • features, kernel_h, kernel_w: floats (cast to int).

Output: 1-D flattened.
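When testing an implementation, a pure-function reference that takes the weights explicitly can help. The sketch below assumes the shapes listed in the problem; `separable_reference` is a hypothetical name, and it deliberately does not handle the seed or parameter initialization, which depend on how the grader builds the module. The final check uses a center-tap depthwise kernel and an identity pointwise kernel, which should reproduce the input.

```python
import jax
import jax.numpy as jnp

DN = ("NHWC", "HWIO", "NHWC")


def separable_reference(x, dw, pw, bias):
    """Depthwise-separable forward pass with explicit weights.

    x: (H, W, C_in); dw: (kh, kw, 1, C_in); pw: (1, 1, C_in, F); bias: (F,).
    Returns the (H, W, F) result flattened to shape (H * W * F,).
    """
    xb = x[None]  # add a batch dim: (1, H, W, C_in)
    h = jax.lax.conv_general_dilated(
        xb, dw, (1, 1), "SAME",
        dimension_numbers=DN, feature_group_count=x.shape[-1])
    y = jax.lax.conv_general_dilated(h, pw, (1, 1), "SAME", dimension_numbers=DN)
    return (y[0] + bias).reshape(-1)


# Inputs arrive as floats; cast the integer-valued ones first.
features, kernel_h, kernel_w = int(5.0), int(3.0), int(3.0)
x = jax.random.normal(jax.random.PRNGKey(3), (6, 6, 3))
c_in = x.shape[-1]

k1, k2 = jax.random.split(jax.random.PRNGKey(4))
dw = jax.random.normal(k1, (kernel_h, kernel_w, 1, c_in))
pw = jax.random.normal(k2, (1, 1, c_in, features))
out = separable_reference(x, dw, pw, jnp.zeros((features,)))
print(out.shape)  # (180,)

# Sanity check: center-tap depthwise plus identity pointwise returns the input.
dw_id = jnp.zeros((3, 3, 1, c_in)).at[1, 1, 0, :].set(1.0)
pw_id = jnp.eye(c_in)[None, None]
same = separable_reference(x, dw_id, pw_id, jnp.zeros((c_in,)))
print(bool(jnp.allclose(same, x.reshape(-1))))  # True
```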

Hints

flax depthwise-separable mobilenet
