medium primitives

Implement Conv1D from Scratch

Why this matters

1-D convolutions show up everywhere: audio processing, time-series models, character-level CNNs for text, and (importantly) as a stepping stone to 2-D conv. Reimplementing Conv1D forces you to confront three things head-on:

  1. The kernel layout convention — Flax/JAX’s WIO (width, in, out).
  2. Dimension numbers — JAX’s flexible but verbose way of telling conv_general_dilated what each axis means.
  3. The fact that jax.lax.conv_general_dilated is the universal conv primitive — nn.Conv is just a thin wrapper that calls it with sensible defaults.

Input layout

Following Flax convention, nn.Conv expects channels-last inputs. For an unbatched 1-D input:

  • Input x: shape (L, C_in) — sequence length × channels.
  • Output y: shape (L_out, C_out).

jax.lax.conv_general_dilated works on batched inputs, so we’ll add a batch dim with x[None, ...], run the conv, then squeeze the batch out.

Kernel layout: WIO

The kernel for 1-D conv has shape (W, I, O):

  • W: kernel width (e.g., 3 for a width-3 filter).
  • I: input channels.
  • O: output channels (= features).

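Compare with PyTorch’s (O, I, W) layout: the axis order is reversed. Flax follows the channels-last convention here, so the kernel’s channel axes come last, just as they do in the (L, C_in) input.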
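As a small illustration of the two layouts, the sketch below converts a kernel stored as (O, I, W) into (W, I, O) by reversing its axes. The array values and the names kernel_oiw and kernel_wio are made up for the example.

```python
import jax.numpy as jnp

# Hypothetical kernel in PyTorch's (O, I, W) layout:
# 2 output channels, 1 input channel, width 3.
kernel_oiw = jnp.arange(6.0).reshape(2, 1, 3)

# Reverse the axis order to obtain the (W, I, O) layout.
kernel_wio = jnp.transpose(kernel_oiw, (2, 1, 0))

print(kernel_wio.shape)  # (3, 1, 2)
```

The same element is now addressed as kernel_wio[w, i, o] instead of kernel_oiw[o, i, w].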

dimension_numbers tuple

conv_general_dilated takes a tuple (lhs_spec, rhs_spec, out_spec):

  • lhs_spec describes input axes — "NWC" means (batch, width, channel).
  • rhs_spec describes kernel axes — "WIO".
  • out_spec describes output axes — "NWC" (same as input).

Each string lists the axes in the order they appear in the corresponding array. JAX reads them to work out which axis is batch, spatial, or channel, so you never have to transpose by hand.
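To see that the strings only describe axis order, here is a minimal sketch that runs the same convolution twice: once channels-last ("NWC", "WIO", "NWC") and once channels-first ("NCW", "OIW", "NCW") on transposed copies of the same data. The channels-first strings and the test arrays are choices made for this example, not something the problem requires.

```python
import jax
import jax.numpy as jnp

x_nwc = jnp.arange(8.0).reshape(1, 4, 2)    # (batch=1, width=4, channels=2)
k_wio = jnp.arange(18.0).reshape(3, 2, 3)   # (width=3, in=2, out=3)

# Channels-last layout, as used throughout this problem.
y_nwc = jax.lax.conv_general_dilated(
    x_nwc, k_wio,
    window_strides=(1,),
    padding="SAME",
    dimension_numbers=("NWC", "WIO", "NWC"),
)

# Same data, physically transposed to channels-first.
x_ncw = jnp.transpose(x_nwc, (0, 2, 1))     # (1, 2, 4)
k_oiw = jnp.transpose(k_wio, (2, 1, 0))     # (3, 2, 3) as (out, in, width)
y_ncw = jax.lax.conv_general_dilated(
    x_ncw, k_oiw,
    window_strides=(1,),
    padding="SAME",
    dimension_numbers=("NCW", "OIW", "NCW"),
)

# Both calls compute the same numbers, just stored in a different axis order.
same = bool(jnp.allclose(y_nwc, jnp.transpose(y_ncw, (0, 2, 1))))
print(y_nwc.shape, y_ncw.shape, same)
```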

Worked example

import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 4.])[:, None]  # shape (4, 1) — len=4, 1 channel
kernel = jnp.ones((3, 1, 2))               # WIO: width=3, in=1, out=2

y = jax.lax.conv_general_dilated(
    x[None, ...],              # add batch dim → (1, 4, 1)
    kernel,
    window_strides=(1,),
    padding="SAME",
    dimension_numbers=("NWC", "WIO", "NWC"),
)
# y shape (1, 4, 2). Drop batch → (4, 2).

With padding="SAME" and stride 1, output length equals input length. With padding="VALID" and stride 1, length shrinks by kernel_size - 1.
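The two padding modes can be compared on the same input as the worked example. This is a small sketch with a single output channel so the numbers are easy to check by hand; the helper name conv is hypothetical.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 4.])[None, :, None]  # (batch=1, L=4, C_in=1)
kernel = jnp.ones((3, 1, 1))                    # WIO: width=3, in=1, out=1

def conv(padding):
    """Stride-1 conv; returns the single output channel as a 1-D array."""
    y = jax.lax.conv_general_dilated(
        x, kernel,
        window_strides=(1,),
        padding=padding,
        dimension_numbers=("NWC", "WIO", "NWC"),
    )
    return y[0, :, 0]

y_same = conv("SAME")    # zero-padded by 1 on each side: [3., 6., 9., 7.]
y_valid = conv("VALID")  # only full windows: [6., 9.]
print(y_same, y_valid)
```

With an all-ones kernel each output is the sum of a width-3 window, so the SAME result shows the zero padding at both ends.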

Why use conv_general_dilated and not jnp.convolve?

jnp.convolve is 1-D, single-channel, and treats the kernel as flipped (mathematical convolution). What we want is the multi-channel, batched cross-correlation every neural net uses — that’s conv_general_dilated.
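The kernel flip is easiest to see with an asymmetric kernel. The sketch below, using made-up values, feeds the same signal and kernel to both functions and shows that they agree only after the kernel is reversed for jnp.convolve.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 4.])
k = jnp.array([1., 0., -1.])   # asymmetric, so flipping matters

# Mathematical convolution: the kernel is flipped before sliding.
y_convolve = jnp.convolve(x, k, mode="valid")          # [2., 2.]

# Cross-correlation: the kernel slides as written.
y_xcorr = jax.lax.conv_general_dilated(
    x[None, :, None],          # (1, 4, 1)
    k[:, None, None],          # (3, 1, 1) in WIO
    window_strides=(1,),
    padding="VALID",
    dimension_numbers=("NWC", "WIO", "NWC"),
)[0, :, 0]                                             # [-2., -2.]

# Reversing the kernel by hand makes jnp.convolve match the cross-correlation.
y_convolve_flipped = jnp.convolve(x, k[::-1], mode="valid")
print(y_convolve, y_xcorr, y_convolve_flipped)
```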

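Once you’ve internalized this primitive, Conv2D and Conv3D follow the same pattern: add the extra spatial axes to the dimension_numbers strings (for example "NHWC", "HWIO", "NHWC"), to the kernel shape, and to window_strides.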
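As a rough sketch of the 2-D case, the same call works once every spatial argument gains a second entry. The shapes below are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 5, 5, 3))       # NHWC: batch=1, height=5, width=5, channels=3
kernel = jnp.ones((3, 3, 3, 4))  # HWIO: 3x3 window, in=3, out=4

y = jax.lax.conv_general_dilated(
    x, kernel,
    window_strides=(1, 1),       # one stride per spatial axis
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print(y.shape)  # (1, 5, 5, 4)
```

With all-ones inputs, an interior output equals 3 * 3 * 3 = 27 (a full window over three channels), while a corner sees only a 2 x 2 patch of real data and equals 12.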

Common pitfalls

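  • Wrong kernel layout: a kernel laid out as (W, O, I) but passed with the "WIO" spec is caught by a shape check only when C_in ≠ features. When the two are equal, the shapes match and the conv runs without error, but it mixes up the input and output channel axes and produces garbage outputs.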
  • Forgetting the batch dim: conv_general_dilated requires a batched input. Use x[None, ...] and squeeze the result with y[0].
  • Missing bias: the bias has shape (O,) and broadcasts across the length axis when added to the conv output.

Problem

Implement MyConv1D(features, kernel_size):

  1. Kernel shape (kernel_size, x.shape[-1], features), init lecun_normal().
  2. Bias shape (features,), init zeros.
  3. Run conv with window_strides=(1,), padding="SAME", dimension_numbers=("NWC", "WIO", "NWC") after adding a batch dim.
  4. Add bias and squeeze batch dim.
  5. Return the output flattened to 1-D for testing (.reshape(-1)).

Inputs:

  • seed: float (cast to int).
  • x: 2-D JAX array, shape (L, C_in).
  • features: int (cast to int).
  • kernel_size: int (cast to int).

Output: 1-D array, length L * features.
