Implement Conv2D with Stride and Padding
Why this matters
Conv2D is the workhorse of vision. ResNets, U-Nets, and EfficientNets are all
built from stacks of Conv2D → norm → activation, typically with strided convs
for downsampling. Reimplementing it carries the patterns from Conv1D over to the
higher-dimensional case and forces you to confront how stride and SAME padding
interact.
Input layout
For 2-D conv:
- Input x: (H, W, C_in), channels-last. (PyTorch defaults to NCHW. Flax uses NHWC because XLA prefers it for memory layout on accelerators. Raw jax.lax.conv_general_dilated defaults to NCHW unless you pass dimension_numbers.)
- Output y: (H_out, W_out, C_out).
Add a batch dim → (1, H, W, C_in) for conv_general_dilated.
Kernel layout: HWIO
The kernel for 2-D conv has shape (kh, kw, C_in, C_out):
- kh, kw: kernel height and width.
- C_in: input channels.
- C_out: output channels.
Stride and SAME padding
With padding="SAME" and stride=s:
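- Output spatial size = ceil(H / s) (and likewise for W).
- JAX picks the padding amounts so the output has exactly that size. The total padding along H is max((H_out - 1) × s + kh - H, 0), split evenly between top and bottom, with any odd extra row going on the bottom. W works the same way, with the extra column on the right.
With stride=1, the output has the same spatial size as the input. With stride=2,
it is roughly half (ceil(H/2)).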
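The size and padding rule can be written out in plain Python for one axis. This is a sketch of the rule described above, and it should match how jax.lax resolves SAME padding. The helper name same_padding is hypothetical.

```python
def same_padding(in_size: int, kernel: int, stride: int) -> tuple[int, int, int]:
    """Return (out_size, pad_lo, pad_hi) for SAME padding along one axis."""
    # Output size is ceil(in_size / stride); -(-a // b) is integer ceil division.
    out_size = -(-in_size // stride)
    # Total padding needed so the last strided window fits in the padded input.
    total = max((out_size - 1) * stride + kernel - in_size, 0)
    # Split evenly; any odd pixel goes on the high (bottom/right) side.
    pad_lo = total // 2
    pad_hi = total - pad_lo
    return out_size, pad_lo, pad_hi


print(same_padding(3, 3, 1))  # (3, 1, 1)
print(same_padding(3, 3, 2))  # (2, 1, 1)
print(same_padding(4, 3, 2))  # (2, 0, 1)
```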
Examples for input H=3, W=3:
- kernel 3×3, stride=1, SAME → output (3, 3, ...).
- kernel 3×3, stride=2, SAME → output (2, 2, ...).
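The two examples can be checked directly with jax.lax.conv_general_dilated. The channel counts (C_in=2, C_out=4) and the all-ones arrays are arbitrary choices, since only the output shapes matter here.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 3, 3, 2))       # (N, H, W, C_in) with H = W = 3
kernel = jnp.ones((3, 3, 2, 4))  # (kh, kw, C_in, C_out), 3x3 window


def out_shape(stride):
    # Run a SAME-padded conv and report only the resulting shape.
    y = jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(stride, stride),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y.shape


print(out_shape(1))  # (1, 3, 3, 4)
print(out_shape(2))  # (1, 2, 2, 4)
```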
Worked example: dimension_numbers
```python
import jax

# x: (H, W, C_in), kernel: (kh, kw, C_in, C_out), stride: int
y = jax.lax.conv_general_dilated(
    x[None, ...],                     # (1, H, W, C_in)
    kernel,                           # (kh, kw, C_in, C_out)
    window_strides=(stride, stride),  # 2-D strides
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
# y shape: (1, H_out, W_out, C_out)
```
Then drop the batch dim (y[0]) and add the bias. A bias of shape (C_out,) broadcasts over the last (channel) axis.
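Putting the call, the batch handling, and the bias together gives a small helper for unbatched input. conv2d_same is a hypothetical name used only for this sketch. The all-ones test case is chosen because each output value then simply counts how many in-bounds input pixels its window covers under zero padding.

```python
import jax
import jax.numpy as jnp


def conv2d_same(x, kernel, bias, stride):
    """Strided SAME conv on an unbatched (H, W, C_in) input.

    kernel: (kh, kw, C_in, C_out); bias: (C_out,).
    Returns an array of shape (ceil(H / stride), ceil(W / stride), C_out).
    """
    y = jax.lax.conv_general_dilated(
        x[None, ...],                     # add batch dim -> (1, H, W, C_in)
        kernel,
        window_strides=(stride, stride),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    # Drop the batch dim, then broadcast the (C_out,) bias over H and W.
    return y[0] + bias


x = jnp.ones((3, 3, 1))
kernel = jnp.ones((3, 3, 1, 1))
bias = jnp.zeros((1,))

# Stride 1: corners see 4 in-bounds pixels, edges 6, the center 9.
print(conv2d_same(x, kernel, bias, 1)[..., 0])
# -> values [[4, 6, 4], [6, 9, 6], [4, 6, 4]]
```

With stride=2 on the same input, the output shrinks to (2, 2, 1). Every window then covers a 2×2 in-bounds patch, so all four outputs equal 4.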
Why care about layout?
On TPU and modern GPUs, NHWC + HWIO is the layout that lets XLA fuse conv
with subsequent ops most effectively. PyTorch's default NCHW layout often needs
a transpose under the hood, because many GPU conv kernels prefer NHWC. PyTorch
calls the NHWC memory format "channels last" (torch.channels_last), and
switching between the two formats is the source of many performance gotchas.
Flax simply chose "channels last, always", a single convention that avoids the whole class of problem.
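To see that the two layouts describe the same computation, the sketch below runs one convolution twice. The first run uses PyTorch-style NCHW/OIHW tensors via dimension_numbers=("NCHW", "OIHW", "NCHW"). The second runs after transposing to Flax's NHWC/HWIO. The shapes and random seed are arbitrary example values.

```python
import jax
import jax.numpy as jnp

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
# PyTorch-style tensors (example shapes): input NCHW, weight OIHW = (C_out, C_in, kh, kw).
x_nchw = jax.random.normal(key_x, (1, 3, 8, 8))
w_oihw = jax.random.normal(key_w, (16, 3, 3, 3))

y_nchw = jax.lax.conv_general_dilated(
    x_nchw,
    w_oihw,
    window_strides=(2, 2),
    padding="SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
)

# Same conv in Flax's layout: NCHW -> NHWC and OIHW -> HWIO.
x_nhwc = jnp.transpose(x_nchw, (0, 2, 3, 1))
w_hwio = jnp.transpose(w_oihw, (2, 3, 1, 0))
y_nhwc = jax.lax.conv_general_dilated(
    x_nhwc,
    w_hwio,
    window_strides=(2, 2),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

# Move channels back to axis 1 and compare; results agree up to float rounding.
match = jnp.allclose(y_nchw, jnp.transpose(y_nhwc, (0, 3, 1, 2)), atol=1e-3)
print(match)
```

Porting PyTorch weights to a Flax-style conv is therefore just the (2, 3, 1, 0) transpose of the weight tensor.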
Common pitfalls
- Using the PyTorch-style (C_out, C_in, kh, kw) kernel layout: wrong here. Flax expects (kh, kw, C_in, C_out).
- Forgetting x[None, ...]: conv_general_dilated requires a batch dim, and the resulting error is cryptic.
- Off-by-one in the output shape with stride > 1: with SAME padding it is ceil(H / stride), not H // stride.
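One way to turn these pitfalls into readable errors is a small shape check before calling the conv. check_conv2d_args is a hypothetical helper, not a JAX or Flax API. The channel check is only a heuristic: a PyTorch-style kernel slips through whenever its kh happens to equal C_in.

```python
import jax.numpy as jnp


def check_conv2d_args(x, kernel, stride):
    """Validate an unbatched NHWC input and an HWIO kernel; return the expected output shape."""
    if x.ndim != 3:
        raise ValueError(f"expected x of shape (H, W, C_in), got {x.shape}")
    if kernel.ndim != 4:
        raise ValueError(f"expected kernel of shape (kh, kw, C_in, C_out), got {kernel.shape}")
    h, w, c_in = x.shape
    _, _, k_in, c_out = kernel.shape
    if k_in != c_in:
        raise ValueError(
            f"kernel C_in={k_in} does not match x C_in={c_in}; "
            "is the kernel in PyTorch (C_out, C_in, kh, kw) order?"
        )
    # SAME padding: ceil division, not floor division.
    return (-(-h // stride), -(-w // stride), c_out)


x = jnp.ones((5, 5, 4))
print(check_conv2d_args(x, jnp.ones((3, 3, 4, 8)), 2))  # (3, 3, 8), whereas 5 // 2 == 2
```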
Problem
Implement MyConv2D(features, kernel_h, kernel_w, stride):
- Kernel shape (kernel_h, kernel_w, x.shape[-1], features), init lecun_normal().
- Bias shape (features,), init zeros.
- Add the batch dim and call conv_general_dilated with window_strides=(stride, stride), padding="SAME", dimension_numbers=("NHWC", "HWIO", "NHWC").
- Drop the batch dim and add the bias.
- Return .reshape(-1).
Inputs:
- seed: float (cast to int).
- x: 3-D JAX array (H, W, C_in).
- features, kernel_h, kernel_w, stride: floats (cast to int).
Output: 1-D array (flattened conv output).