1-D Conv via lax.conv_general_dilated
Why this matters
lax.conv_general_dilated is JAX's lower-level convolution primitive. It
gives you full control over padding, strides, dilation, and dimension layout,
and it is the operation that higher-level libraries such as Flax and Haiku build their convolution layers on.
Understanding it is essential for:
- Writing custom CNN architectures without a framework.
- Scientific computing and signal processing in JAX.
- Debugging shape errors in conv layers.
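To make the stride and dilation controls concrete, here is a small sketch on an 8-sample signal with a 3-tap box kernel. The signal values and the choice of stride 2 and kernel dilation 2 are illustrative assumptions; `window_strides` and `rhs_dilation` are real parameters of `lax.conv_general_dilated`.

```python
import jax.numpy as jnp
from jax import lax

# Toy signal [1, 2, ..., 8] and a 3-tap box kernel, both laid out as (N, C, length)
x = jnp.arange(1.0, 9.0).reshape(1, 1, -1)   # shape (1, 1, 8)
k = jnp.ones((1, 1, 3))                       # shape (1, 1, 3)
dn = ('NCH', 'OIH', 'NCH')

# Stride 2: the window moves two samples at a time -> windows start at 0, 2, 4
strided = lax.conv_general_dilated(x, k, window_strides=(2,), padding='VALID',
                                   dimension_numbers=dn)

# Kernel (rhs) dilation 2: taps sit at offsets 0, 2, 4, so each window spans 5 samples
dilated = lax.conv_general_dilated(x, k, window_strides=(1,), padding='VALID',
                                   rhs_dilation=(2,), dimension_numbers=dn)

print(strided.reshape(-1))   # values 6, 12, 18      (1+2+3, 3+4+5, 5+6+7)
print(dilated.reshape(-1))   # values 9, 12, 15, 18  (1+3+5, 2+4+6, ...)
```

Dilation widens the receptive field without adding kernel weights, while the stride shrinks the output by skipping window positions.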
The dimension_numbers argument is a tuple of three layout strings describing the
input, the kernel, and the output, in that order. For 1-D conv, 'NCH' means
batch (N), channel (C), length (H). The kernel layout 'OIH' means
output-channels (O), input-channels (I), kernel-length (H).
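A sketch of how the layout strings steer the same computation: the first call uses the channel-first layout from the text, the second stores the same data channel-last, and the third shows the roles of O and I with several channels. The channel-last spec ('NHC', 'HIO', 'NHC') and the all-ones multi-channel data are illustrative choices, not something the problem requires.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = jnp.array([1.0, 1.0, 1.0])

# Channel-first, as in the text: input (N, C, H), kernel (O, I, H), output (N, C, H)
out_nch = lax.conv_general_dilated(
    x.reshape(1, 1, -1), kernel.reshape(1, 1, -1),
    window_strides=(1,), padding='VALID',
    dimension_numbers=('NCH', 'OIH', 'NCH'))

# Channel-last: input (N, H, C), kernel (H, I, O), output (N, H, C).
# Only the layout strings and reshapes change; the arithmetic is identical.
out_nhc = lax.conv_general_dilated(
    x.reshape(1, -1, 1), kernel.reshape(-1, 1, 1),
    window_strides=(1,), padding='VALID',
    dimension_numbers=('NHC', 'HIO', 'NHC'))

print(out_nch.shape, out_nhc.shape)                              # (1, 1, 3) (1, 3, 1)
print(jnp.allclose(out_nch.reshape(-1), out_nhc.reshape(-1)))   # True

# Two input channels mixed into three output channels: kernel is (O=3, I=2, H=3)
x2 = jnp.ones((1, 2, 5))
k2 = jnp.ones((3, 2, 3))
out2 = lax.conv_general_dilated(x2, k2, (1,), 'VALID',
                                dimension_numbers=('NCH', 'OIH', 'NCH'))
# out2 has shape (1, 3, 3); every value sums 2 channels x 3 taps of 1.0 -> 6.0
```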
Note: lax.conv_general_dilated performs CROSS-CORRELATION (no kernel
flip), unlike numpy.convolve, which flips the kernel. For symmetric kernels
the results match (use numpy.convolve(..., mode='valid') to get the same output
length as VALID padding); for asymmetric kernels the two differ.
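The difference shows up as soon as the kernel is asymmetric. This sketch uses a difference filter [1, 0, -1] (an illustrative choice) and compares the lax result with numpy.convolve in 'valid' mode; the helper name `xcorr_valid` is hypothetical.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = jnp.array([1.0, 0.0, -1.0])   # asymmetric difference filter

def xcorr_valid(x, k):
    """Hypothetical helper: VALID 1-D cross-correlation via lax.conv_general_dilated."""
    out = lax.conv_general_dilated(
        x.reshape(1, 1, -1), k.reshape(1, 1, -1),
        window_strides=(1,), padding='VALID',
        dimension_numbers=('NCH', 'OIH', 'NCH'))
    return out.reshape(-1)

lax_out = xcorr_valid(x, kernel)   # out[i] = x[i] - x[i+2]
np_out = np.convolve(np.asarray(x), np.asarray(kernel), mode='valid')  # kernel flipped

print(lax_out)   # values -2, -2, -2
print(np_out)    # values  2,  2,  2

# Flipping the kernel by hand makes the lax result match numpy.convolve
print(np.allclose(np.asarray(xcorr_valid(x, kernel[::-1])), np_out))   # True
```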
Worked mini-example
import jax.numpy as jnp
from jax import lax
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = jnp.array([1.0, 1.0, 1.0])
# Reshape to 3-D: (batch=1, channel=1, length)
x_3d = x.reshape(1, 1, -1) # shape (1, 1, 5)
k_3d = kernel.reshape(1, 1, -1) # shape (1, 1, 3)
out = lax.conv_general_dilated(
x_3d, k_3d,
window_strides=(1,),
padding='VALID',
dimension_numbers=('NCH', 'OIH', 'NCH')
)
# out.reshape(-1) → [6.0, 9.0, 12.0] (length 5 - 3 + 1 = 3)
Common pitfalls
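- Must reshape to 3-D: lax.conv_general_dilated expects (batch, channel, length) operands; passing a 1-D array raises an error.
- dimension_numbers maps each axis of the input, kernel, and output. If omitted, it defaults to the channel-first layout (('NCH', 'OIH', 'NCH') in 1-D), so pass it explicitly whenever your data is laid out differently. A wrong spec either raises a shape error or, if the shapes happen to line up, silently computes over the wrong axes.
- VALID vs SAME: 'VALID' shrinks the output to length L - K + 1; 'SAME' pads with zeros so that, with stride 1, the output has the same length as the input.
- Cross-correlation, not convolution: the kernel is NOT flipped.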
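The padding and reshape pitfalls can be checked directly on the worked example's data. In this sketch, SAME padding with a 3-tap kernel adds one zero at each end, which is why the edge outputs are smaller. The exception raised for 1-D inputs is caught generically because its exact type and message may vary between JAX versions.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = jnp.array([1.0, 1.0, 1.0])
dn = ('NCH', 'OIH', 'NCH')
x_3d, k_3d = x.reshape(1, 1, -1), kernel.reshape(1, 1, -1)

valid = lax.conv_general_dilated(x_3d, k_3d, (1,), 'VALID', dimension_numbers=dn)
same = lax.conv_general_dilated(x_3d, k_3d, (1,), 'SAME', dimension_numbers=dn)

print(valid.reshape(-1))  # values 6, 9, 12          (length 5 - 3 + 1 = 3)
print(same.reshape(-1))   # values 3, 6, 9, 12, 9    (length 5; input padded to [0, 1, ..., 5, 0])

# Skipping the reshape: rank-1 operands do not fit the 3-letter layout strings
try:
    lax.conv_general_dilated(x, kernel, (1,), 'VALID', dimension_numbers=dn)
    raised = False
except Exception:   # exact exception type may differ across JAX versions
    raised = True
print(raised)   # True
```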
Problem
Implement conv1d_basic(x, kernel) using lax.conv_general_dilated with
VALID padding.
- x: 1-D jax array (length L).
- kernel: 1-D jax array (length K).
- Returns: 1-D array of length L - K + 1.
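One possible sketch of conv1d_basic, assembled from the worked example above: add batch and channel axes, run a VALID cross-correlation, then flatten the result. This is an illustration under those assumptions, not the problem's official solution.

```python
import jax.numpy as jnp
from jax import lax

def conv1d_basic(x, kernel):
    """VALID 1-D cross-correlation of x (length L) with kernel (length K).

    Returns a 1-D array of length L - K + 1.
    """
    # Add batch and channel axes: (L,) -> (1, 1, L) and (K,) -> (1, 1, K)
    lhs = x.reshape(1, 1, -1)
    rhs = kernel.reshape(1, 1, -1)
    out = lax.conv_general_dilated(
        lhs, rhs,
        window_strides=(1,),
        padding='VALID',
        dimension_numbers=('NCH', 'OIH', 'NCH'),
    )
    # Drop the singleton batch and channel axes: (1, 1, L - K + 1) -> (L - K + 1,)
    return out.reshape(-1)

print(conv1d_basic(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]),
                   jnp.array([1.0, 1.0, 1.0])))   # values 6, 9, 12
```

Because the kernel is not flipped, an asymmetric kernel such as [1, 0, -1] yields x[i] - x[i+2] at each position rather than its negation.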