
Conv Padding & Strides

Why this matters

The padding argument to lax.conv_general_dilated controls how the output length relates to the input length. With input length L and kernel size K:

  • 'VALID': no padding is added. At stride 1 the output length is L - K + 1, so the signal shrinks by K - 1 elements.
  • 'SAME': zeros are padded so that the output has the same length as the input (at stride 1). JAX follows the TensorFlow convention of splitting the total padding as evenly as possible. When the total is odd (for example, an even kernel size at stride 1), the extra zero goes at the end, so the padding is asymmetric.
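Both rules can be written as plain integer arithmetic. This is a sketch assuming an undilated kernel and a stride `stride`; with stride 1 it reproduces the two bullets above. `valid_out_len` and `same_pads` are hypothetical helper names. The `SAME` split puts the smaller half in front and any odd leftover at the end, mirroring the convention described above.

```python
import math

def valid_out_len(L: int, K: int, stride: int = 1) -> int:
    """'VALID' output length: number of window positions that fit fully inside the input."""
    # Clamp at 0: when K > L the kernel never fits.
    return max(0, (L - K) // stride + 1)

def same_pads(L: int, K: int, stride: int = 1) -> tuple[int, int]:
    """(front, back) zero padding for 'SAME', targeting ceil(L / stride) outputs."""
    out_len = math.ceil(L / stride)
    # Total padding so the last window, starting at (out_len - 1) * stride, still fits.
    total = max(0, (out_len - 1) * stride + K - L)
    front = total // 2              # smaller half in front
    return front, total - front     # odd leftover goes at the end

print(valid_out_len(5, 3))  # 3
print(same_pads(5, 3))      # (1, 1)
print(same_pads(5, 2))      # (0, 1)
```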

This choice appears throughout CNN design. VALID is common for feature extraction when you want to avoid outputs that depend on zero padding at the boundaries. SAME is common in residual and skip-connection architectures, where a block's output must match its input shape.
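To make the shape argument concrete, here is a minimal residual-style sketch. It assumes a single-channel 1-D signal and the same `lax.conv_general_dilated` call pattern as the worked example; `conv1d` is a hypothetical helper. With `'SAME'` the convolution output can be added straight back to its input. With `'VALID'` the lengths differ by K - 1, so the skip path has to be cropped.

```python
import jax.numpy as jnp
from jax import lax

def conv1d(x, kernel, padding):
    # Hypothetical helper: single-batch, single-channel 1-D convolution.
    out = lax.conv_general_dilated(
        x.reshape(1, 1, -1), kernel.reshape(1, 1, -1),
        window_strides=(1,), padding=padding,
        dimension_numbers=('NCH', 'OIH', 'NCH'),
    )
    return out.reshape(-1)

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
k = jnp.array([1.0, 1.0, 1.0])

# 'SAME' keeps the length, so the skip connection is a plain addition.
y = x + conv1d(x, k, 'SAME')

# 'VALID' drops K - 1 = 2 elements; crop the skip path (centered) to match.
h = conv1d(x, k, 'VALID')
crop = (k.shape[0] - 1) // 2
y_valid = x[crop:crop + h.shape[0]] + h
print(y.shape, y_valid.shape)  # (5,) (3,)
```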

Worked mini-example

import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = jnp.array([1.0, 1.0, 1.0])

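# lax convolutions expect explicit batch and channel axes:
# input (N, C, W) = (1, 1, 5), kernel (O, I, W) = (1, 1, 3).
# They compute a cross-correlation (no kernel flip); with this symmetric kernel it makes no difference.
x_3d = x.reshape(1, 1, -1)
k_3d = kernel.reshape(1, 1, -1)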

# VALID: output length 5 - 3 + 1 = 3
out_valid = lax.conv_general_dilated(
    x_3d, k_3d, (1,), 'VALID', dimension_numbers=('NCH', 'OIH', 'NCH')
).reshape(-1)
# → [6.0, 9.0, 12.0]

# SAME: output length = input length = 5
out_same = lax.conv_general_dilated(
    x_3d, k_3d, (1,), 'SAME', dimension_numbers=('NCH', 'OIH', 'NCH')
).reshape(-1)
# → [3.0, 6.0, 9.0, 12.0, 9.0]
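The padding argument also accepts an explicit sequence of (low, high) pairs, one per spatial dimension. As a sanity check on the example above, this sketch rebuilds the same inputs and confirms two equivalences for a length-5 input and a length-3 kernel at stride 1: `'VALID'` matches `[(0, 0)]`, and `'SAME'` matches `[(1, 1)]`. The helper `conv` is hypothetical.

```python
import jax.numpy as jnp
from jax import lax

x_3d = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(1, 1, -1)
k_3d = jnp.array([1.0, 1.0, 1.0]).reshape(1, 1, -1)

def conv(padding):
    # Hypothetical helper: same call as the worked example, padding varied.
    return lax.conv_general_dilated(
        x_3d, k_3d, (1,), padding, dimension_numbers=('NCH', 'OIH', 'NCH')
    ).reshape(-1)

# 'VALID' is the same as zero explicit padding.
assert jnp.allclose(conv('VALID'), conv([(0, 0)]))
# For K = 3 at stride 1, 'SAME' pads one zero on each side.
assert jnp.allclose(conv('SAME'), conv([(1, 1)]))
print(conv([(1, 1)]).tolist())  # [3.0, 6.0, 9.0, 12.0, 9.0]
```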

Common pitfalls

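  • 'SAME' pads asymmetrically for even kernel sizes: for kernel size 2 at stride 1, JAX pads 0 elements at the front and 1 at the back, so the front gets less padding.
  • 'VALID' always shrinks: the output length is L - K + 1 at stride 1. For K > L the kernel never fits inside the input, so there is no valid output position. Depending on the library this can produce an empty output or an error, so guard against this case explicitly.
  • Strides and padding interact: with stride > 1, the 'SAME' output length is ceil(L / stride), not L, and the 'VALID' output length is floor((L - K) / stride) + 1.
  • padding takes a string ('VALID' or 'SAME') or a sequence of (low, high) integer pairs, one per spatial dimension, so don't pass a number such as 0 in place of 'VALID'. Explicit pairs give full control, but you must compute the pad sizes yourself. For example, [(0, 0)] reproduces 'VALID', and [(1, 1)] reproduces 'SAME' for K = 3 at stride 1.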
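The even-kernel and stride pitfalls above can be checked directly. This sketch uses the same single-channel setup as the worked example, with `conv` as a hypothetical helper. It shows the (0, 1) split for a length-2 kernel and how stride 2 changes the output length for both modes.

```python
import math
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

def conv(x, kernel, padding, stride=1):
    # Hypothetical helper: 1-D convolution with batch = channels = 1.
    out = lax.conv_general_dilated(
        x.reshape(1, 1, -1), kernel.reshape(1, 1, -1),
        window_strides=(stride,), padding=padding,
        dimension_numbers=('NCH', 'OIH', 'NCH'),
    )
    return out.reshape(-1)

# Even kernel: 'SAME' pads 0 in front and 1 at the back -> input seen as [1, 2, 3, 4, 5, 0].
k2 = jnp.array([1.0, 1.0])
print(conv(x, k2, 'SAME').tolist())  # [3.0, 5.0, 7.0, 9.0, 5.0]
assert jnp.allclose(conv(x, k2, 'SAME'), conv(x, k2, [(0, 1)]))

# Stride 2: 'SAME' gives ceil(5 / 2) = 3 outputs, 'VALID' gives (5 - 3) // 2 + 1 = 2.
k3 = jnp.array([1.0, 1.0, 1.0])
print(conv(x, k3, 'SAME', stride=2).tolist())   # [3.0, 9.0, 9.0]
print(conv(x, k3, 'VALID', stride=2).tolist())  # [6.0, 12.0]
assert conv(x, k3, 'SAME', stride=2).shape[0] == math.ceil(5 / 2)
```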

Problem

Implement conv_with_padding(x, kernel, padding_mode) that applies 1-D convolution using lax.conv_general_dilated with the padding selected by padding_mode.

  • x: 1-D jax array.
  • kernel: 1-D jax array.
  • padding_mode: scalar — 0.0 means 'VALID', 1.0 means 'SAME'.
  • Returns: 1-D array.
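One possible sketch, not the site's reference solution. It assumes stride 1, which the problem statement doesn't fix, and the same single-batch, single-channel layout as the worked example. The two modes produce different output shapes, so the choice cannot be made with something like `jnp.where`. The sketch therefore assumes `padding_mode` is a concrete value rather than a traced one under `jit`. Raising on values other than 0.0 and 1.0 is also an added assumption.

```python
import jax.numpy as jnp
from jax import lax

def conv_with_padding(x, kernel, padding_mode):
    # Assumption: padding_mode is concrete (not traced), since the mode changes the output shape.
    mode = float(padding_mode)
    if mode == 0.0:
        padding = 'VALID'
    elif mode == 1.0:
        padding = 'SAME'
    else:
        raise ValueError(f"padding_mode must be 0.0 or 1.0, got {mode}")
    out = lax.conv_general_dilated(
        x.reshape(1, 1, -1),       # input (N, C, W)
        kernel.reshape(1, 1, -1),  # kernel (O, I, W)
        window_strides=(1,),       # assumption: stride 1
        padding=padding,
        dimension_numbers=('NCH', 'OIH', 'NCH'),
    )
    return out.reshape(-1)
```

For example, with x = [1, 2, 3, 4, 5] and kernel = [1, 1, 1], mode 0.0 returns the three VALID values and mode 1.0 returns the five SAME values from the worked example.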

Hints

jax conv padding strides
