
1-D Conv via lax.conv_general_dilated

Why this matters

lax.conv_general_dilated is JAX’s general-purpose convolution primitive. It gives you explicit control over padding, strides, dilation (of both the input and the kernel), and dimension layout, and it is the operation that higher-level libraries such as Flax build their convolution layers on.
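To make those controls concrete, here is a small sketch that runs one length-8 signal through three settings: a stride (window_strides), a kernel dilation (rhs_dilation), and explicit (low, high) padding. The signal and the box kernel are illustrative choices, not part of the original problem; the parameter names are those of lax.conv_general_dilated.

```python
import jax.numpy as jnp
from jax import lax

# Length-8 signal [1, 2, ..., 8] laid out as (batch=1, channel=1, length=8)
x = jnp.arange(1.0, 9.0).reshape(1, 1, -1)
# Width-3 box kernel laid out as (out_channels=1, in_channels=1, width=3)
k = jnp.ones((1, 1, 3))
dn = ('NCH', 'OIH', 'NCH')

# Stride 2: the window advances two samples per output,
# giving floor((8 - 3) / 2) + 1 = 3 outputs.
strided = lax.conv_general_dilated(
    x, k, window_strides=(2,), padding='VALID', dimension_numbers=dn)

# Kernel dilation 2: the three taps sit two samples apart, so the
# effective width is (3 - 1) * 2 + 1 = 5 and VALID gives 8 - 5 + 1 = 4 outputs.
dilated = lax.conv_general_dilated(
    x, k, window_strides=(1,), padding='VALID',
    rhs_dilation=(2,), dimension_numbers=dn)

# Explicit padding: one zero before and one after the signal (low=1, high=1),
# giving (8 + 2) - 3 + 1 = 8 outputs.
padded = lax.conv_general_dilated(
    x, k, window_strides=(1,), padding=[(1, 1)], dimension_numbers=dn)

print(strided.reshape(-1))  # values 6, 12, 18
print(dilated.reshape(-1))  # values 9, 12, 15, 18
print(padded.reshape(-1))   # values 3, 6, 9, 12, 15, 18, 21, 15
```

Each output entry is simply the sum of the samples under the (possibly spread-out) window, which makes the effect of each parameter easy to check by hand.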

Understanding it is essential for:

  • Writing custom CNN architectures without a framework.
  • Scientific computing and signal processing in JAX.
  • Debugging shape errors in conv layers.

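The dimension_numbers argument is a triple of layout strings, one each for the input, the kernel, and the output. For a 1-D conv, 'NCH' means batch (N), channel (C), length (H). The kernel layout 'OIH' means output channels (O), input channels (I), kernel length (H). The spatial letter itself is arbitrary, but it must be the same in all three strings.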
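The sketch below shows that the layout strings only describe where each axis lives: the same signal in channel-last form ('NHC' input, 'HIO' kernel) gives the same numbers as the channel-first 'NCH'/'OIH' form, just with a different output shape. The second part uses all-ones inputs (an illustrative assumption) to show how O and I relate input channels to output channels.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = jnp.array([1.0, 1.0, 1.0])

# Channel-first: input (N, C, H), kernel (O, I, H), output (N, C, H)
out_nch = lax.conv_general_dilated(
    x.reshape(1, 1, -1), kernel.reshape(1, 1, -1),
    window_strides=(1,), padding='VALID',
    dimension_numbers=('NCH', 'OIH', 'NCH'))

# Channel-last: input (N, H, C), kernel (H, I, O), output (N, H, C).
# Only the specs and the reshapes change; the arithmetic is the same.
out_nhc = lax.conv_general_dilated(
    x.reshape(1, -1, 1), kernel.reshape(-1, 1, 1),
    window_strides=(1,), padding='VALID',
    dimension_numbers=('NHC', 'HIO', 'NHC'))

print(out_nch.shape, out_nhc.shape)  # (1, 1, 3) (1, 3, 1)

# O and I in practice: 2 input channels mapped to 3 output channels.
x2 = jnp.ones((1, 2, 5))   # (N=1, C=2, H=5)
k2 = jnp.ones((3, 2, 3))   # (O=3, I=2, H=3): the kernel's I must equal the input's C
y2 = lax.conv_general_dilated(
    x2, k2, window_strides=(1,), padding='VALID',
    dimension_numbers=('NCH', 'OIH', 'NCH'))
print(y2.shape)            # (1, 3, 3); each entry sums 2 channels x 3 taps of ones = 6.0
```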

Note: lax.conv_general_dilated performs CROSS-CORRELATION (no kernel flip), unlike numpy.convolve, which flips the kernel; the NumPy counterpart is numpy.correlate. For symmetric kernels the two give identical results; for asymmetric kernels they differ unless you reverse the kernel yourself (kernel[::-1]).
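A quick way to see the difference is an asymmetric kernel. In this sketch, lax_valid is a hypothetical helper wrapping the VALID call from the worked example; its output matches numpy.correlate, and it matches numpy.convolve only after the kernel is reversed.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
w = jnp.array([1.0, 2.0, 3.0])   # asymmetric kernel

def lax_valid(x, w):
    # Hypothetical helper: VALID 1-D cross-correlation with 'NCH'/'OIH' layout
    out = lax.conv_general_dilated(
        x.reshape(1, 1, -1), w.reshape(1, 1, -1),
        window_strides=(1,), padding='VALID',
        dimension_numbers=('NCH', 'OIH', 'NCH'))
    return out.reshape(-1)

corr = lax_valid(x, w)            # x[i]*1 + x[i+1]*2 + x[i+2]*3 -> 14, 20, 26
np_corr = np.correlate(np.asarray(x), np.asarray(w), mode='valid')  # 14, 20, 26
np_conv = np.convolve(np.asarray(x), np.asarray(w), mode='valid')   # 10, 16, 22

# Reversing the kernel turns lax's cross-correlation into a true convolution.
flipped = lax_valid(x, w[::-1])   # 10, 16, 22
```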

Worked mini-example

import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = jnp.array([1.0, 1.0, 1.0])

# Reshape to 3-D: (batch=1, channel=1, length)
x_3d = x.reshape(1, 1, -1)     # shape (1, 1, 5)
k_3d = kernel.reshape(1, 1, -1) # shape (1, 1, 3)

out = lax.conv_general_dilated(
    x_3d, k_3d,
    window_strides=(1,),
    padding='VALID',
    dimension_numbers=('NCH', 'OIH', 'NCH')
)
# out has shape (1, 1, 3); out.reshape(-1) → [6.0, 9.0, 12.0]  (length 5 - 3 + 1 = 3)
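The numbers 6, 9, 12 come from sliding the kernel across the signal without flipping it: out[i] = sum over k of x[i + k] * kernel[k], for i = 0 .. L - K. The sketch below spells that formula out with a hypothetical helper, valid_xcorr_reference, and checks it against the lax call; the loop is for illustration only, not a replacement for the primitive.

```python
import jax.numpy as jnp
from jax import lax

def valid_xcorr_reference(x, kernel):
    # Hypothetical helper: out[i] = sum_k x[i + k] * kernel[k], i = 0 .. L - K
    L, K = x.shape[0], kernel.shape[0]
    return jnp.stack([jnp.dot(x[i:i + K], kernel) for i in range(L - K + 1)])

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = jnp.array([1.0, 1.0, 1.0])

ref = valid_xcorr_reference(x, kernel)  # values 6.0, 9.0, 12.0
lax_out = lax.conv_general_dilated(
    x.reshape(1, 1, -1), kernel.reshape(1, 1, -1),
    window_strides=(1,), padding='VALID',
    dimension_numbers=('NCH', 'OIH', 'NCH')).reshape(-1)
```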

Common pitfalls

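  • Must reshape to 3-D: with the 'NCH'/'OIH' specs, lax.conv_general_dilated expects rank-3 operands, (batch, channel, length) for the input and (out_channels, in_channels, length) for the kernel; passing the raw 1-D arrays raises an error.
  • Be explicit with dimension_numbers: it is optional, but the default assumes channel-first layouts, (batch, channel, spatial) for input and output and (out, in, spatial) for the kernel. If your data is laid out differently you must pass matching specs; a wrong spec either raises a shape error or, when axis sizes happen to line up, silently treats the wrong axis as channels.
  • VALID vs SAME: 'VALID' applies no padding, so the output shrinks to length L - K + 1; 'SAME' zero-pads so that, with stride 1, the output has the same length L as the input.
  • Cross-correlation, not convolution: the kernel is NOT flipped.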
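The first and third pitfalls can be checked directly. This sketch reuses the worked-example inputs; run is a hypothetical helper. It compares VALID and SAME output lengths and confirms that un-reshaped 1-D inputs are rejected (the exact exception type may vary between JAX versions, so the sketch catches any exception).

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = jnp.array([1.0, 1.0, 1.0])
dn = ('NCH', 'OIH', 'NCH')

def run(padding):
    # Hypothetical helper: 1-D cross-correlation with the given padding mode
    return lax.conv_general_dilated(
        x.reshape(1, 1, -1), kernel.reshape(1, 1, -1),
        window_strides=(1,), padding=padding, dimension_numbers=dn).reshape(-1)

valid = run('VALID')  # length 5 - 3 + 1 = 3: values 6, 9, 12
same = run('SAME')    # length 5; one zero on each side: values 3, 6, 9, 12, 9

# Passing the raw 1-D arrays fails: the specs describe rank-3 operands.
try:
    lax.conv_general_dilated(x, kernel, (1,), 'VALID', dimension_numbers=dn)
    raised = False
except Exception:
    raised = True
```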

Problem

Implement conv1d_basic(x, kernel) using lax.conv_general_dilated with VALID padding.

  • x: 1-D jax array (length L).
  • kernel: 1-D jax array (length K).
  • Returns: 1-D array of length L - K + 1.
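One possible sketch of conv1d_basic, following the layout of the worked example. It assumes floating-point inputs and a kernel no longer than x, and it is not an official reference solution.

```python
import jax
import jax.numpy as jnp
from jax import lax

def conv1d_basic(x, kernel):
    """VALID 1-D cross-correlation of a length-L signal with a length-K kernel."""
    lhs = x.reshape(1, 1, -1)       # (batch=1, channel=1, L)
    rhs = kernel.reshape(1, 1, -1)  # (out_channels=1, in_channels=1, K)
    out = lax.conv_general_dilated(
        lhs, rhs,
        window_strides=(1,),
        padding='VALID',
        dimension_numbers=('NCH', 'OIH', 'NCH'),
    )
    return out.reshape(-1)          # (L - K + 1,)

y = conv1d_basic(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]), jnp.array([1.0, 1.0, 1.0]))
# y values: 6.0, 9.0, 12.0

# Shapes are static at trace time, so the function also works under jax.jit.
conv1d_jit = jax.jit(conv1d_basic)
```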
