Conv Padding & Strides
Why this matters
The padding argument to lax.conv_general_dilated controls how the output
length relates to the input length. With input length L and kernel size K:
- 'VALID': no padding is added. The output length is L - K + 1, so the signal shrinks by K - 1 elements.
- 'SAME': zeros are padded so that the output has the same length as the input (for stride 1). JAX follows the TensorFlow convention: when the total padding is odd, as it is for even kernel sizes at stride 1, the extra zero goes at the end (asymmetric padding).
This choice comes up constantly in CNN design: 'VALID' is common in feature extraction when you want to discard boundary effects, while 'SAME' is common in residual and skip-connection architectures, where shapes must match so that feature maps can be added or concatenated.
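The length arithmetic can be checked without running a convolution. The sketch below is a hypothetical helper (conv1d_output_info is not a JAX function) that assumes no dilation and stride >= 1. For 'SAME' it targets ceil(L / stride) outputs and puts the smaller half of the padding at the front, matching the convention described above.

```python
import math


def conv1d_output_info(L, K, stride=1, padding="VALID"):
    """Return (output_length, (pad_lo, pad_hi)) for a 1-D convolution.

    Hypothetical helper: assumes no dilation and stride >= 1.
    """
    if padding == "VALID":
        pads = (0, 0)
    elif padding == "SAME":
        # Aim for ceil(L / stride) outputs, then pad just enough to get there.
        target = math.ceil(L / stride)
        total = max((target - 1) * stride + K - L, 0)
        # Smaller half in front, larger half at the back.
        pads = (total // 2, total - total // 2)
    else:
        raise ValueError(f"unsupported padding: {padding!r}")

    padded = L + pads[0] + pads[1]
    # Count window positions that fit entirely inside the padded signal;
    # this is zero when the kernel is longer than the padded input.
    out_len = max((padded - K) // stride + 1, 0)
    return out_len, pads
```

For example, conv1d_output_info(5, 2, padding="SAME") returns (5, (0, 1)), the asymmetric even-kernel case.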
Worked mini-example
import jax.numpy as jnp
from jax import lax
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = jnp.array([1.0, 1.0, 1.0])
# Add batch and channel axes: input is (N, C, H), kernel is (O, I, H)
x_3d = x.reshape(1, 1, -1)
k_3d = kernel.reshape(1, 1, -1)
# VALID: output length 5 - 3 + 1 = 3
out_valid = lax.conv_general_dilated(
x_3d, k_3d, (1,), 'VALID', dimension_numbers=('NCH', 'OIH', 'NCH')
).reshape(-1)
# → [6.0, 9.0, 12.0]
# SAME: output length = input length = 5
out_same = lax.conv_general_dilated(
x_3d, k_3d, (1,), 'SAME', dimension_numbers=('NCH', 'OIH', 'NCH')
).reshape(-1)
# → [3.0, 6.0, 9.0, 12.0, 9.0]
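The same results can be reproduced with explicit padding, which makes the pad sizes visible. A minimal sketch reusing the worked example's input and kernel; conv is a hypothetical local helper, and the padding list holds one (low, high) pair per spatial dimension.

```python
import jax.numpy as jnp
from jax import lax

# Same input and kernel as the worked example, with batch and channel axes added.
x_3d = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(1, 1, -1)
k_3d = jnp.ones((1, 1, 3))
dn = ('NCH', 'OIH', 'NCH')


def conv(padding):
    # Hypothetical helper: stride 1; `padding` is 'VALID', 'SAME',
    # or a list with one (low, high) pair per spatial dimension.
    out = lax.conv_general_dilated(x_3d, k_3d, (1,), padding, dimension_numbers=dn)
    return out.reshape(-1)


# For K = 3 and stride 1, 'SAME' adds one zero on each side.
same_explicit = conv([(1, 1)])   # values 3, 6, 9, 12, 9 (same as 'SAME')
# No padding at all is exactly 'VALID'.
valid_explicit = conv([(0, 0)])  # values 6, 9, 12 (same as 'VALID')
```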
Common pitfalls
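- 'SAME' pads asymmetrically for even kernel sizes: for kernel size 2 (stride 1), JAX pads 0 at the front and 1 at the back, so the front gets less padding.
- 'VALID' always shrinks the signal: the output length is L - K + 1. For K > L no window fits inside the input, so there is no meaningful output; guard against this case explicitly.
- Strides and padding interact: with stride > 1, the 'SAME' output length is ceil(L / stride), not L, and the 'VALID' output length is floor((L - K) / stride) + 1.
- 'VALID' is not the only way to ask for no padding: padding also accepts an explicit sequence of (low, high) pairs, one per spatial dimension. [(0, 0)] is equivalent to 'VALID', but any other explicit padding requires computing the pad sizes yourself.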
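The even-kernel and stride pitfalls are easy to confirm empirically. A minimal sketch reusing the worked example's input with all-ones kernels, so each output is simply a window sum; the comments give those sums worked out by hand under the padding convention described above, and conv is a hypothetical local helper.

```python
import jax.numpy as jnp
from jax import lax

dn = ('NCH', 'OIH', 'NCH')
x_3d = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(1, 1, -1)


def conv(kernel_size, stride, padding):
    # All-ones kernel, so each output is just the sum of one window.
    k_3d = jnp.ones((1, 1, kernel_size))
    out = lax.conv_general_dilated(x_3d, k_3d, (stride,), padding, dimension_numbers=dn)
    return out.reshape(-1)


# Even kernel (K = 2): 'SAME' pads (0, 1), nothing in front and one zero at the back.
# Windows [1,2] [2,3] [3,4] [4,5] [5,0] -> 3, 5, 7, 9, 5
same_k2 = conv(2, 1, 'SAME')
explicit_k2 = conv(2, 1, [(0, 1)])

# Stride 2 with K = 3: 'SAME' pads (1, 1) and yields ceil(5 / 2) = 3 outputs, not 5.
# Padded input [0,1,2,3,4,5,0]; windows start at 0, 2, 4 -> 3, 9, 9
same_s2 = conv(3, 2, 'SAME')
```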
Problem
Implement conv_with_padding(x, kernel, padding_mode) that applies 1-D
convolution using lax.conv_general_dilated with the padding selected by
padding_mode.
- x: 1-D jax array.
- kernel: 1-D jax array.
- padding_mode: scalar; 0.0 means 'VALID', 1.0 means 'SAME'.
- Returns: 1-D array.
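The task reduces to mapping the scalar onto one of the two padding strings and wrapping the 1-D arrays in batch and channel axes, as in the worked example. A minimal sketch, assuming the function is called eagerly (outside jit) and that padding_mode is exactly 0.0 or 1.0; it is not necessarily the intended reference solution.

```python
import jax.numpy as jnp
from jax import lax


def conv_with_padding(x, kernel, padding_mode):
    # `padding` must be a static Python string, so read the scalar mode eagerly.
    # float() on a traced value fails, so this sketch is meant to run outside jit.
    mode = float(padding_mode)
    if mode == 0.0:
        padding = 'VALID'
    elif mode == 1.0:
        padding = 'SAME'
    else:
        raise ValueError(f"padding_mode must be 0.0 or 1.0, got {mode}")

    out = lax.conv_general_dilated(
        x.reshape(1, 1, -1),       # input as (batch, channels, length)
        kernel.reshape(1, 1, -1),  # kernel as (out_channels, in_channels, width)
        window_strides=(1,),
        padding=padding,
        dimension_numbers=('NCH', 'OIH', 'NCH'),
    )
    return out.reshape(-1)         # drop the batch and channel axes
```

Calling conv_with_padding(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]), jnp.ones(3), 1.0) reproduces the 'SAME' result of the worked example. If the function has to be jitted, passing padding_mode as a Python float marked static (jax.jit(..., static_argnums=2)) keeps it a concrete value, since the two modes produce different output shapes.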