
Implement Conv2D with Stride and Padding

Why this matters

Conv2D is the workhorse of vision models. ResNets, U-Nets, and EfficientNets are all, at their core, stacks of Conv2D → norm → activation blocks, often with strided convolutions for downsampling. Reimplementing it carries the patterns from Conv1D into the higher-dimensional case and forces you to confront how stride and SAME padding interact.

Input layout

For 2-D conv:

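  • Input x: (H, W, C_in), channels-last. (PyTorch uses channels-first NCHW; Flax’s nn.Conv uses NHWC, a layout that tends to suit accelerators well. Note that jax.lax.conv_general_dilated itself defaults to NCHW/OIHW unless you pass dimension_numbers.)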
  • Output y: (H_out, W_out, C_out).

conv_general_dilated expects a leading batch dimension, so add one: (1, H, W, C_in).
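A minimal sketch of this shape bookkeeping, using arbitrary sizes (H=5, W=7, C_in=3) chosen only for illustration. The NCHW-to-NHWC transpose at the end matters only if your data arrives in a PyTorch-style layout.

```python
import jax.numpy as jnp

# A single channels-last image: H=5, W=7, C_in=3 (arbitrary sizes).
x = jnp.zeros((5, 7, 3))

# conv_general_dilated wants a leading batch axis, so add one.
x_batched = x[None, ...]                        # (1, 5, 7, 3), i.e. NHWC

# A PyTorch-style NCHW batch can be moved to NHWC by permuting axes.
x_nchw = jnp.zeros((1, 3, 5, 7))                # (N, C, H, W)
x_nhwc = jnp.transpose(x_nchw, (0, 2, 3, 1))    # (N, H, W, C)
```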

Kernel layout: HWIO

In Flax, the 2-D conv kernel has shape (kh, kw, C_in, C_out):

  • kh, kw: kernel height and width.
  • C_in: input channels.
  • C_out: output channels.
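If you are porting weights, a PyTorch nn.Conv2d weight is stored as (C_out, C_in, kh, kw), i.e. OIHW. A sketch of the axis permutation that turns it into HWIO, with made-up sizes and dummy values standing in for real weights:

```python
import jax.numpy as jnp

C_out, C_in, kh, kw = 8, 3, 3, 5

# Dummy OIHW weight (C_out, C_in, kh, kw); arange makes every entry distinct.
w_oihw = jnp.arange(C_out * C_in * kh * kw, dtype=jnp.float32).reshape(C_out, C_in, kh, kw)

# Reorder to (kh, kw, C_in, C_out): w_hwio[i, j, c, o] == w_oihw[o, c, i, j].
w_hwio = jnp.transpose(w_oihw, (2, 3, 1, 0))
```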

Stride and SAME padding

With padding="SAME" and stride=s:

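  • Output spatial size = ceil(H / s) (and likewise for W).
  • JAX chooses the padding so that exactly that many windows fit: the total padding along H is max((H_out - 1) × s + kh - H, 0), split evenly between top and bottom, with any odd leftover going to the bottom (likewise for W with kw, left and right).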

With stride=1, the output has the same spatial size as the input. With stride=2, it is roughly halved: ceil(H/2).
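The SAME arithmetic can be written out for one spatial axis. This is a sketch: the function name is hypothetical, and the formula is intended to agree with the one JAX uses for padding="SAME" (total padding split evenly, odd leftover on the high side).

```python
def same_output_and_pads(size, kernel, stride):
    """Output length and (lo, hi) padding for SAME along one spatial axis."""
    out = -(-size // stride)  # ceil(size / stride) using integer arithmetic
    # Total padding so that `out` windows of length `kernel`, spaced `stride`
    # apart, fit inside the padded axis.
    total = max((out - 1) * stride + kernel - size, 0)
    lo = total // 2
    return out, (lo, total - lo)  # any odd extra goes to the bottom/right


print(same_output_and_pads(3, 3, 1))  # (3, (1, 1))
print(same_output_and_pads(4, 3, 2))  # (2, (0, 1))
```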

Examples for input H=3, W=3:

  • kernel 3×3, stride=1, SAME → output (3, 3, ...).
  • kernel 3×3, stride=2, SAME → output (2, 2, ...).
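Both examples can be checked directly. A small sketch with an all-ones 3×3 single-channel input and an all-ones 3×3 kernel producing 4 output channels (arbitrary choices); with stride=1, the center output sums all 9 inputs, while a corner window overlaps only 4 real pixels because of the zero padding.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 3, 3, 1))   # NHWC: one 3x3 single-channel image
k = jnp.ones((3, 3, 1, 4))   # HWIO: 3x3 kernel, 1 -> 4 channels


def conv_same(x, k, stride):
    return jax.lax.conv_general_dilated(
        x, k,
        window_strides=(stride, stride),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


y1 = conv_same(x, k, 1)   # shape (1, 3, 3, 4)
y2 = conv_same(x, k, 2)   # shape (1, 2, 2, 4)
```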

Worked example: dimension_numbers

import jax

# x: (H, W, C_in), kernel: (kh, kw, C_in, C_out), stride: int
y = jax.lax.conv_general_dilated(
    x[None, ...],                                    # (1, H, W, C_in)
    kernel,                                          # (kh, kw, C_in, C_out)
    window_strides=(stride, stride),                 # 2-D strides
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
# y shape: (1, H_out, W_out, C_out)

Then drop the batch dim (y[0]) and add the bias, which broadcasts over the spatial axes.
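Putting the call together with the batch-dim and bias steps gives a small functional helper. The name conv2d_same is hypothetical; the usage below uses a zero kernel so the output is just the bias broadcast over every spatial position.

```python
import jax
import jax.numpy as jnp


def conv2d_same(x, kernel, bias, stride):
    """x: (H, W, C_in), kernel: (kh, kw, C_in, C_out), bias: (C_out,)."""
    y = jax.lax.conv_general_dilated(
        x[None, ...],                      # add batch dim -> (1, H, W, C_in)
        kernel,
        window_strides=(stride, stride),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y[0] + bias                     # drop batch dim; bias broadcasts over H_out, W_out


x = jnp.ones((5, 5, 2))
kernel = jnp.zeros((3, 3, 2, 4))
bias = jnp.arange(4.0)
out = conv2d_same(x, kernel, bias, stride=2)   # (3, 3, 4), since ceil(5 / 2) = 3
```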

Why care about layout?

On TPUs and on GPUs with Tensor Cores, NHWC activations with HWIO kernels are generally the efficient choice, and XLA can fuse the conv with the element-wise ops that follow it. PyTorch’s default NCHW layout often has to be transposed internally for fast GPU kernels. PyTorch’s opt-in remedy is the “channels last” memory format (torch.channels_last), and juggling the two layouts is the source of many performance gotchas.

Flax simply settled on “channels last, always”. It is a one-way choice, but it sidesteps this whole class of problems.
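The layout changes how tensors are labeled and stored, not the math. A small check, with an arbitrary seed and sizes, that NHWC/HWIO and NCHW/OIHW give the same result in JAX once the tensors are permuted consistently:

```python
import jax
import jax.numpy as jnp

key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
x_nhwc = jax.random.normal(key_x, (1, 6, 6, 3))
k_hwio = jax.random.normal(key_k, (3, 3, 3, 5))

y_nhwc = jax.lax.conv_general_dilated(
    x_nhwc, k_hwio, (1, 1), "SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))

# The same computation expressed in PyTorch-style NCHW / OIHW layout.
x_nchw = jnp.transpose(x_nhwc, (0, 3, 1, 2))
k_oihw = jnp.transpose(k_hwio, (3, 2, 0, 1))
y_nchw = jax.lax.conv_general_dilated(
    x_nchw, k_oihw, (1, 1), "SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"))

# Move the NCHW result back to NHWC and compare.
same = jnp.allclose(y_nhwc, jnp.transpose(y_nchw, (0, 2, 3, 1)), atol=1e-4)
```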

Common pitfalls

  • (C_out, C_in, kh, kw) layout (PyTorch’s OIHW): wrong here. Flax expects (kh, kw, C_in, C_out), i.e. HWIO.
  • Forgetting x[None, ...]: conv_general_dilated requires a batch dim. The error is cryptic.
  • Off-by-one in output shape with stride>1: SAME gives ceil(H / stride), not H // stride.
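The off-by-one shows up as soon as H is odd. A sketch with H=W=5 and stride=2 (arbitrary values): the conv produces 3 rows, which matches the ceiling, not the floor.

```python
import jax
import jax.numpy as jnp

H, W, stride = 5, 5, 2
x = jnp.zeros((1, H, W, 1))
k = jnp.zeros((3, 3, 1, 1))
y = jax.lax.conv_general_dilated(
    x, k, (stride, stride), "SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))

ceil_h = -(-H // stride)   # 3: matches y.shape[1]
floor_h = H // stride      # 2: the off-by-one mistake
```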

Problem

Implement MyConv2D(features, kernel_h, kernel_w, stride):

  1. Kernel shape (kernel_h, kernel_w, x.shape[-1], features), init lecun_normal().
  2. Bias shape (features,), init zeros.
  3. Add batch dim, call conv_general_dilated with window_strides=(stride, stride), padding="SAME", dimension_numbers=("NHWC", "HWIO", "NHWC").
  4. Drop batch dim, add bias.
  5. Return the output flattened with .reshape(-1).

Inputs:

  • seed: float (cast to int).
  • x: 3-D JAX array (H, W, C_in).
  • features, kernel_h, kernel_w, stride: floats (cast to int).

Output: 1-D array (flattened conv output).
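The five steps can be traced without Flax. This is a framework-free sketch, not the Flax solution: the helper name my_conv2d is hypothetical, and the kernel is drawn directly from PRNGKey(seed). A Flax module would instead create the kernel and bias with self.param (for example self.param("kernel", nn.initializers.lecun_normal(), shape)), deriving its key from the "params" RNG stream, so the shapes match but the random values generally will not.

```python
import jax
import jax.numpy as jnp


def my_conv2d(seed, x, features, kernel_h, kernel_w, stride):
    """Hypothetical framework-free version of the MyConv2D steps."""
    seed, features = int(seed), int(features)
    kernel_h, kernel_w, stride = int(kernel_h), int(kernel_w), int(stride)

    # 1. Kernel (kh, kw, C_in, C_out), LeCun-normal init. The initializer's
    #    default in_axis=-2 / out_axis=-1 line up with the HWIO layout.
    key = jax.random.PRNGKey(seed)
    kernel = jax.nn.initializers.lecun_normal()(
        key, (kernel_h, kernel_w, x.shape[-1], features))
    # 2. Bias (features,), zeros.
    bias = jnp.zeros((features,))
    # 3. Add batch dim and convolve with SAME padding.
    y = jax.lax.conv_general_dilated(
        x[None, ...], kernel,
        window_strides=(stride, stride),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))
    # 4. Drop batch dim and add bias; 5. flatten.
    return (y[0] + bias).reshape(-1)


out = my_conv2d(0.0, jnp.ones((5, 5, 3)), 4.0, 3.0, 3.0, 2.0)  # 3 * 3 * 4 = 36 values
```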

Hints

flax conv2d self-param
