hard
primitives
2-D Conv via lax.conv_general_dilated
Why this matters
2-D convolution with lax.conv_general_dilated extends the 1-D case to
image-shaped data. Both the input and the kernel must be 4-D tensors; the
input is (batch, channel, H, W), the standard PyTorch NCHW layout. This
primitive is the foundation of CNN layers written in JAX/Flax.
dimension_numbers=('NCHW', 'OIHW', 'NCHW') tells JAX:
- Input axes: N (batch), C (channels), H (height), W (width).
- Kernel axes: O (output-channels), I (input-channels), H, W.
- Output axes: same order as the input (N, C, H, W), where C is now the number of output channels.
For TensorFlow-style (channels last), use ('NHWC', 'HWIO', 'NHWC').
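To make the two layouts concrete, here is a small sketch that runs the same single-channel convolution once in channels-first form and once in channels-last form, then compares them. The input values (jnp.arange) and the variable names are illustrative assumptions, not part of the exercise.

```python
import jax.numpy as jnp
from jax import lax

# The same 3x3 single-channel image, stored in two layouts.
x = jnp.arange(9.0).reshape(3, 3)
kernel = jnp.ones((2, 2))

# Channels-first: input is NCHW, kernel is OIHW.
out_nchw = lax.conv_general_dilated(
    x.reshape(1, 1, 3, 3), kernel.reshape(1, 1, 2, 2),
    window_strides=(1, 1), padding='VALID',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
)

# Channels-last: input is NHWC, kernel is HWIO.
out_nhwc = lax.conv_general_dilated(
    x.reshape(1, 3, 3, 1), kernel.reshape(2, 2, 1, 1),
    window_strides=(1, 1), padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)

# Move the channel axis back to position 1 so the results can be compared.
same = bool(jnp.allclose(out_nchw, jnp.transpose(out_nhwc, (0, 3, 1, 2))))
print(out_nchw.shape, out_nhwc.shape, same)
```

Only the positions of the axes differ between the two calls; the numerical result is the same once the axes are lined up.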
Worked mini-example
import jax.numpy as jnp
from jax import lax
x = jnp.eye(3) # 3×3 identity
kernel = jnp.ones((2, 2))
x_4d = x.reshape(1, 1, 3, 3)
k_4d = kernel.reshape(1, 1, 2, 2)
out = lax.conv_general_dilated(
x_4d, k_4d,
window_strides=(1, 1),
padding='VALID',
dimension_numbers=('NCHW', 'OIHW', 'NCHW')
)
result = out.reshape(out.shape[2], out.shape[3])
# result shape: (2, 2)
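One way to sanity-check the mini-example is to compare it against an explicit sliding-window sum. The sketch below uses a hypothetical helper, sliding_window_sum, written with plain NumPy loops. It multiplies each window by the kernel without flipping it, which is how lax.conv_general_dilated applies the kernel.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def sliding_window_sum(x, k):
    """Reference VALID convolution (no kernel flip) using NumPy loops."""
    H, W = x.shape
    Kh, Kw = k.shape
    out = np.zeros((H - Kh + 1, W - Kw + 1), dtype=x.dtype)
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            # Elementwise product of the window and the kernel, then sum.
            out[i, j] = np.sum(x[i:i + Kh, j:j + Kw] * k)
    return out

x = np.eye(3, dtype=np.float32)
kernel = np.ones((2, 2), dtype=np.float32)

out = lax.conv_general_dilated(
    jnp.asarray(x).reshape(1, 1, 3, 3),
    jnp.asarray(kernel).reshape(1, 1, 2, 2),
    window_strides=(1, 1),
    padding='VALID',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
)
result = np.asarray(out[0, 0])          # drop batch and channel axes
expected = sliding_window_sum(x, kernel)
print(result)    # each entry is the sum of one 2x2 window of the identity
print(expected)
```

For the 3×3 identity and a 2×2 kernel of ones, each output entry counts the ones inside its window, so the diagonal entries are 2 and the off-diagonal entries are 1.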
Common pitfalls
- Must reshape to 4-D: lax.conv_general_dilated requires (batch, channel, H, W); a 2-D array will error.
- dimension_numbers must match your layout: using NCHW strings with NHWC data can silently produce wrong results.
- Output shape with VALID (stride 1): (H - Kh + 1, W - Kw + 1); the image shrinks by (Kh-1) rows and (Kw-1) columns.
- window_strides needs one entry per spatial dimension: pass (sh, sw), not a scalar.
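The shape and rank pitfalls can be checked directly. The sketch below uses a hypothetical helper, valid_conv_shape, to read off the spatial output shape for a given input and kernel size. The stride-2 case uses floor((H - Kh) / sh) + 1, a generalisation of the stride-1 formula that is not stated above and is included here as an assumption to verify. The final block passes 2-D arrays to show that they are rejected; the exact exception type may vary between JAX versions, so it is caught broadly.

```python
import jax.numpy as jnp
from jax import lax

def valid_conv_shape(h, w, kh, kw, strides=(1, 1)):
    """Run a single-channel VALID convolution and return its spatial output shape."""
    x = jnp.zeros((1, 1, h, w))
    k = jnp.zeros((1, 1, kh, kw))
    out = lax.conv_general_dilated(
        x, k, window_strides=strides, padding='VALID',
        dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
    )
    return out.shape[2:]

# Stride 1: (H - Kh + 1, W - Kw + 1) = (5 - 2 + 1, 7 - 3 + 1)
shape_stride1 = valid_conv_shape(5, 7, 2, 3)

# Stride 2: ((5 - 2) // 2 + 1, (7 - 3) // 2 + 1)
shape_stride2 = valid_conv_shape(5, 7, 2, 3, strides=(2, 2))

# Passing 2-D arrays with 4-character layout strings is rejected.
try:
    lax.conv_general_dilated(
        jnp.eye(3), jnp.ones((2, 2)),
        window_strides=(1, 1), padding='VALID',
        dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
    )
    rejected_2d = False
except Exception:  # exact exception type depends on the JAX version
    rejected_2d = True

print(shape_stride1, shape_stride2, rejected_2d)
```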
Problem
Implement conv2d_basic(x, kernel) using lax.conv_general_dilated with
VALID padding and NCHW layout.
- x: 2-D jax array (H, W).
- kernel: 2-D jax array (Kh, Kw).
- Returns: 2-D array of shape (H - Kh + 1, W - Kw + 1).
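A minimal sketch of one possible approach follows, built from the worked mini-example rather than from the site's reference solution. It assumes x and kernel share the same floating-point dtype, and it adds the batch and channel axes by indexing with None instead of calling reshape.

```python
import jax.numpy as jnp
from jax import lax

def conv2d_basic(x, kernel):
    """VALID 2-D convolution of an (H, W) array with a (Kh, Kw) kernel."""
    # Add singleton batch and channel axes: (H, W) -> (1, 1, H, W).
    x_4d = x[None, None, :, :]
    # Add singleton output- and input-channel axes: (Kh, Kw) -> (1, 1, Kh, Kw).
    k_4d = kernel[None, None, :, :]
    out = lax.conv_general_dilated(
        x_4d, k_4d,
        window_strides=(1, 1),
        padding='VALID',
        dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
    )
    # Drop the singleton batch and channel axes again.
    return out[0, 0]

y = conv2d_basic(jnp.eye(3), jnp.ones((2, 2)))
print(y.shape)
```

Indexing with out[0, 0] avoids recomputing the output shape by hand, so the function does not need to know H, W, Kh, or Kw explicitly.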
Hints
jax
conv
lax
2d