hard primitives

2-D Conv via lax.conv_general_dilated

Why this matters

2-D convolution with lax.conv_general_dilated extends the 1-D case to image-shaped data. The inputs must be 4-D tensors; this section uses (batch, channel, H, W), the NCHW layout that PyTorch uses by default. The same primitive underlies the convolution layers in JAX libraries such as Flax.

dimension_numbers=('NCHW', 'OIHW', 'NCHW') tells JAX:

  • Input axes: N (batch), C (channels), H (height), W (width).
  • Kernel axes: O (output-channels), I (input-channels), H, W.
  • Output axes: N, C, H, W, the same layout as the input.

For TensorFlow-style (channels last), use ('NHWC', 'HWIO', 'NHWC').
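As a rough illustration of the two layouts, the sketch below runs the same convolution once with the channels-last spec and once with the channels-first spec, then checks that the results agree after transposing. The array sizes and variable names are assumptions chosen for the example, not part of the exercise.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative sizes (assumptions, not from the text above).
n, h, w, c_in, c_out = 2, 5, 5, 3, 4

# Channels-last input (N, H, W, C) and kernel (H, W, I, O).
x_nhwc = jnp.arange(n * h * w * c_in, dtype=jnp.float32).reshape(n, h, w, c_in)
k_hwio = jnp.arange(3 * 3 * c_in * c_out, dtype=jnp.float32).reshape(3, 3, c_in, c_out)

out_nhwc = lax.conv_general_dilated(
    x_nhwc, k_hwio,
    window_strides=(1, 1),
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)
print(out_nhwc.shape)  # (2, 3, 3, 4)

# Same data moved to channels-first and convolved with the NCHW spec.
x_nchw = jnp.transpose(x_nhwc, (0, 3, 1, 2))   # NHWC -> NCHW
k_oihw = jnp.transpose(k_hwio, (3, 2, 0, 1))   # HWIO -> OIHW
out_nchw = lax.conv_general_dilated(
    x_nchw, k_oihw,
    window_strides=(1, 1),
    padding='VALID',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
)

# Moving the result back to channels-last should reproduce out_nhwc.
same = jnp.allclose(jnp.transpose(out_nchw, (0, 2, 3, 1)), out_nhwc)
print(bool(same))
```

The layout strings only describe how the axes are ordered; the arithmetic is the same, so the choice usually follows whatever layout the rest of the pipeline already uses.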

Worked mini-example

import jax.numpy as jnp
from jax import lax

x = jnp.eye(3)          # 3×3 identity
kernel = jnp.ones((2, 2))

x_4d = x.reshape(1, 1, 3, 3)
k_4d = kernel.reshape(1, 1, 2, 2)

out = lax.conv_general_dilated(
    x_4d, k_4d,
    window_strides=(1, 1),
    padding='VALID',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW')
)
result = out.reshape(out.shape[2], out.shape[3])
# result shape: (2, 2); values: [[2., 1.], [1., 2.]]

Common pitfalls

  • Must reshape to 4-D: with these dimension_numbers, lax.conv_general_dilated expects (batch, channel, H, W); a 2-D array will error.
  • dimension_numbers must match your layout: using NCHW strings with NHWC data either fails on a channel-count mismatch or, when the shapes happen to line up, silently produces wrong results.
  • Output shape with VALID: (H - Kh + 1, W - Kw + 1) at stride 1; the image shrinks by (Kh - 1) rows and (Kw - 1) columns.
  • window_strides is 2-D: pass a pair (sh, sw), not a scalar.
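The shape rule in the list above can be checked with a small sketch. The helper name valid_conv_shape is hypothetical, introduced only for this illustration; it builds zero-filled NCHW arrays and reports the output shape. For strides larger than 1, each spatial size becomes floor((size - kernel) / stride) + 1.

```python
import jax.numpy as jnp
from jax import lax

def valid_conv_shape(h, w, kh, kw, strides):
    """Hypothetical helper: output shape of a VALID NCHW conv on zero-filled inputs."""
    x = jnp.zeros((1, 1, h, w))    # (batch, channel, H, W)
    k = jnp.zeros((1, 1, kh, kw))  # (out-channels, in-channels, Kh, Kw)
    out = lax.conv_general_dilated(
        x, k,
        window_strides=strides,    # always a pair (sh, sw)
        padding='VALID',
        dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
    )
    return out.shape

# Stride 1: (H - Kh + 1, W - Kw + 1) = (7 - 3 + 1, 5 - 2 + 1)
print(valid_conv_shape(7, 5, 3, 2, (1, 1)))  # (1, 1, 5, 4)

# Stride 2 along H only: floor((7 - 3) / 2) + 1 = 3 rows
print(valid_conv_shape(7, 5, 3, 2, (2, 1)))  # (1, 1, 3, 4)
```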

Problem

Implement conv2d_basic(x, kernel) using lax.conv_general_dilated with VALID padding and NCHW layout.

  • x: 2-D jax array (H, W).
  • kernel: 2-D jax array (Kh, Kw).
  • Returns: 2-D array of shape (H - Kh + 1, W - Kw + 1).
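One possible sketch of the requested function follows, together with a plain NumPy sliding-window reference to compare against. It is not the site's official solution. The helper conv2d_reference is hypothetical and exists only for checking. Note that lax.conv_general_dilated does not flip the kernel, so it computes what signal-processing texts call cross-correlation; the reference does the same.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def conv2d_basic(x, kernel):
    # Add singleton batch and channel axes: (H, W) -> (1, 1, H, W).
    x_4d = x[None, None, :, :]
    k_4d = kernel[None, None, :, :]
    out = lax.conv_general_dilated(
        x_4d, k_4d,
        window_strides=(1, 1),
        padding='VALID',
        dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
    )
    # Drop the batch and channel axes again.
    return out[0, 0]

def conv2d_reference(x, kernel):
    """Hypothetical checker: explicit sliding window, no kernel flip."""
    x = np.asarray(x)
    kernel = np.asarray(kernel)
    h, w = x.shape
    kh, kw = kernel.shape
    out = np.zeros((h - kh + 1, w - kw + 1), dtype=x.dtype)
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * kernel)
    return out

x = jnp.eye(3)
kernel = jnp.ones((2, 2))
print(conv2d_basic(x, kernel))
# [[2. 1.]
#  [1. 2.]]
```

Both arguments need the same dtype; jnp.eye and jnp.ones both default to float32, so the example above satisfies that without a cast.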

Hints

jax conv lax 2d
