Implement Depthwise-Separable Convolution
Why this matters
Depthwise-separable convolution is the building block of MobileNet, EfficientNet, Xception, and most efficient mobile/edge models. The idea: split a regular conv into two cheaper steps and you get 5–10× fewer parameters and FLOPs for similar accuracy.
A regular (kh, kw) conv from C_in → C_out has parameter count
kh * kw * C_in * C_out.
A depthwise-separable conv replaces it with:
- Depthwise (kh, kw) conv: each input channel is filtered with its own kernel (no cross-channel mixing). Params: kh * kw * C_in.
- Pointwise (1, 1) conv: mixes channels with no spatial extent. Params: C_in * C_out.
Total: kh * kw * C_in + C_in * C_out. For a typical setting
(kh = kw = 3, C_in = C_out = 64), that's 576 + 4096 = 4672 vs
3 * 3 * 64 * 64 = 36864, about 8× cheaper.
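The arithmetic above can be checked with a few lines of plain Python. This is a minimal sketch; the helper names `regular_conv_params` and `separable_conv_params` are made up for illustration, and biases are ignored, as in the formulas above.

```python
def regular_conv_params(kh: int, kw: int, c_in: int, c_out: int) -> int:
    """Weights in a regular (kh, kw) conv from c_in to c_out channels (no bias)."""
    return kh * kw * c_in * c_out


def separable_conv_params(kh: int, kw: int, c_in: int, c_out: int) -> int:
    """Weights in the depthwise + pointwise pair (no bias)."""
    depthwise = kh * kw * c_in  # one (kh, kw) filter per input channel
    pointwise = c_in * c_out    # 1x1 conv that mixes channels
    return depthwise + pointwise


regular = regular_conv_params(3, 3, 64, 64)
separable = separable_conv_params(3, 3, 64, 64)
print(regular, separable, round(regular / separable, 2))  # 36864 4672 7.89
```

The ratio works out to 1 / (1/C_out + 1/(kh * kw)), so for 3×3 kernels the saving approaches 9× as C_out grows.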
Depthwise via feature_group_count
jax.lax.conv_general_dilated has a feature_group_count argument:
- feature_group_count=1 (default): regular conv, where every output channel depends on every input channel.
- feature_group_count=C_in: each input channel is filtered independently, producing one output per input.
For depthwise, we set feature_group_count=C_in. The kernel layout
is (kh, kw, 1, C_in) — one filter per input channel.
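One way to convince yourself that feature_group_count=C_in really keeps channels separate is to perturb a single input channel and see which output channels change. The sketch below assumes a batch of one 8×8 image with 4 channels and random weights; the sizes, the seed, and the helper name `depthwise` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
c_in = 4
x = jax.random.normal(key_x, (1, 8, 8, c_in))       # NHWC, batch of 1
kernel = jax.random.normal(key_k, (3, 3, 1, c_in))  # HWIO: I=1, O=C_in


def depthwise(inp):
    """Depthwise 3x3 conv: one filter per input channel."""
    return jax.lax.conv_general_dilated(
        inp, kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        feature_group_count=c_in,
    )


y = depthwise(x)

# Perturb only input channel 2, then measure the largest change per output channel.
x_perturbed = x.at[..., 2].add(1.0)
delta = jnp.abs(depthwise(x_perturbed) - y).max(axis=(0, 1, 2))
changed = delta > 1e-6

print(y.shape)   # same spatial size and channel count as the input
print(changed)   # only output channel 2 should react
```

With feature_group_count=1 and a (3, 3, c_in, c_in) kernel, the same perturbation would typically move every output channel.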
Pointwise = 1×1 conv
In Flax, a pointwise conv is just nn.Conv(features=C_out, kernel_size=(1, 1)).
It mixes channels at each spatial location independently, which makes it
equivalent to a Dense layer applied to every pixel.
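The equivalence with a per-pixel Dense layer can be checked numerically. This sketch uses raw `jax.lax.conv_general_dilated` rather than a Flax layer so that the same weights can be reused as a plain (C_in, C_out) matrix; the shapes and the tolerance are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
c_in, c_out = 4, 6
x = jax.random.normal(key_x, (1, 5, 5, c_in))      # NHWC
w = jax.random.normal(key_w, (1, 1, c_in, c_out))  # HWIO, 1x1 kernel

# Pointwise conv: a 1x1 kernel, so no spatial neighbourhood is involved.
conv_out = jax.lax.conv_general_dilated(
    x, w,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

# The same weights applied as a (C_in, C_out) matrix at every pixel.
dense_out = jnp.einsum("nhwc,co->nhwo", x, w[0, 0])

print(conv_out.shape)
print(jnp.allclose(conv_out, dense_out, atol=1e-4))
```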
Worked structure
import jax
import flax.linen as nn


class DepthwiseSeparable(nn.Module):
    features: int  # output channels
    kernel_h: int
    kernel_w: int

    @nn.compact
    def __call__(self, x):  # x: (H, W, C_in)
        c_in = x.shape[-1]
        depthwise_kernel = self.param(
            "depthwise_kernel",
            nn.initializers.lecun_normal(),
            (self.kernel_h, self.kernel_w, 1, c_in),  # (kh, kw, 1, C_in)
        )
        pointwise_kernel = self.param(
            "pointwise_kernel",
            nn.initializers.lecun_normal(),
            (1, 1, c_in, self.features),  # (1, 1, C_in, C_out)
        )
        bias = self.param("bias", nn.initializers.zeros, (self.features,))

        # conv_general_dilated expects a batch axis: (H, W, C) -> (1, H, W, C).
        x_b = x[None, ...]

        # Step 1: depthwise, feature_group_count=c_in.
        h = jax.lax.conv_general_dilated(
            x_b, depthwise_kernel,
            window_strides=(1, 1),
            padding="SAME",
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
            feature_group_count=c_in,
        )

        # Step 2: pointwise (1x1).
        h = jax.lax.conv_general_dilated(
            h, pointwise_kernel,
            window_strides=(1, 1),
            padding="SAME",
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
        )

        # Drop the batch axis again and add the per-output-channel bias.
        return h[0] + bias
The module has three parameters: two kernels (depthwise_kernel,
pointwise_kernel) and one bias added to the output. Note: real
implementations sometimes interleave a norm and a non-linearity between
depthwise and pointwise (MobileNet-V2 style: depthwise → BN → ReLU →
pointwise → BN). We omit those here for simplicity.
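Because nothing non-linear sits between the two steps in this simplified version, the depthwise and pointwise pair computes the same function as a single regular conv whose kernel is constrained to the product K[i, j, c, o] = depthwise[i, j, 0, c] * pointwise[0, 0, c, o]. The sketch below checks this with raw JAX calls (no Flax, no bias); the shapes, the seed, and the helper name `conv` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

DN = ("NHWC", "HWIO", "NHWC")
k_x, k_dw, k_pw = jax.random.split(jax.random.PRNGKey(0), 3)
c_in, c_out = 3, 5
x = jax.random.normal(k_x, (1, 6, 6, c_in))
dw = jax.random.normal(k_dw, (3, 3, 1, c_in))      # depthwise kernel
pw = jax.random.normal(k_pw, (1, 1, c_in, c_out))  # pointwise kernel


def conv(inp, kernel, groups=1):
    """Stride-1, SAME-padded conv with an optional group count."""
    return jax.lax.conv_general_dilated(
        inp, kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=DN,
        feature_group_count=groups,
    )


# Depthwise followed by pointwise.
separable_out = conv(conv(x, dw, groups=c_in), pw)

# Equivalent full kernel: K[i, j, c, o] = dw[i, j, 0, c] * pw[0, 0, c, o].
full_kernel = dw[:, :, 0, :, None] * pw[0, 0][None, None, :, :]
regular_out = conv(x, full_kernel)

print(full_kernel.shape)  # (3, 3, c_in, c_out)
print(jnp.allclose(separable_out, regular_out, atol=1e-4))
```

This is also where the parameter saving comes from: the full kernel has kh * kw * C_in * C_out entries, but they are all determined by the much smaller depthwise and pointwise kernels. The equivalence no longer holds once a non-linearity is inserted between the two steps.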
Common pitfalls
- feature_group_count != C_in: any other value gives a “grouped conv” but not a proper depthwise. Set it to C_in exactly.
- Wrong depthwise kernel layout: (kh, kw, 1, C_in) not (kh, kw, C_in, 1). The I=1 slot is the per-group input channel count (which is 1 for depthwise); the O=C_in slot is the total output channels.
- Forgetting pointwise: depthwise alone doesn’t mix channels at all. Without pointwise, you’ve lost the ability to combine information across input channels.
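The kernel-layout pitfall is easy to reproduce. In the sketch below, the helper `try_depthwise` (a made-up name) runs a depthwise conv with a given kernel shape and reports either the output shape or the name of the exception raised during shape checking. The exact exception type is not guaranteed to be the same across JAX versions, so the sketch catches both ValueError and TypeError.

```python
import jax
import jax.numpy as jnp

c_in = 4
x = jnp.ones((1, 8, 8, c_in))  # NHWC, batch of 1


def try_depthwise(kernel_shape):
    """Return the output shape, or the exception name if shape checking fails."""
    kernel = jnp.ones(kernel_shape)
    try:
        out = jax.lax.conv_general_dilated(
            x, kernel,
            window_strides=(1, 1),
            padding="SAME",
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
            feature_group_count=c_in,
        )
        return out.shape
    except (TypeError, ValueError) as err:
        return type(err).__name__


good = try_depthwise((3, 3, 1, c_in))  # I=1, O=C_in: accepted
bad = try_depthwise((3, 3, c_in, 1))   # I and O swapped: rejected
print(good)
print(bad)
```

With the swapped layout, the per-group input channel count (C_in / feature_group_count = 1) no longer matches the kernel's I dimension, so the call fails before any computation happens.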
Problem
Implement MyDepthwiseSeparable(features, kernel_h, kernel_w):
- depthwise_kernel: (kernel_h, kernel_w, 1, C_in), lecun_normal().
- pointwise_kernel: (1, 1, C_in, features), lecun_normal().
- bias: (features,), zeros.
- Run depthwise with feature_group_count=c_in, then pointwise.
- Return the result flattened with .reshape(-1).
Use padding="SAME", stride (1, 1) throughout.
Inputs:
- seed: float (cast to int).
- x: 3-D (H, W, C_in).
- features, kernel_h, kernel_w: floats (cast to int).
Output: 1-D flattened.