Implement Conv1D from Scratch
Why this matters
1-D convolutions show up everywhere: audio processing, time-series models, character-level text CNNs, and (importantly) as a stepping stone to 2-D conv. Reimplementing Conv1D forces you to confront three things head-on:
- The kernel layout convention: Flax/JAX’s WIO (width, in, out).
- Dimension numbers: JAX’s flexible but verbose way of telling conv_general_dilated what each axis means.
- The fact that jax.lax.conv_general_dilated is the universal conv primitive: nn.Conv is just a thin wrapper that calls it with sensible defaults.
Input layout
Following Flax convention, Conv expects channels-last for 1-D inputs:
- Input x: shape (L, C_in), sequence length × channels.
- Output y: shape (L_out, C_out).
jax.lax.conv_general_dilated works on batched inputs, so we’ll add a
batch dim with x[None, ...], run the conv, then squeeze the batch out.
Kernel layout: WIO
The kernel for 1-D conv has shape (W, I, O):
- W: kernel width (e.g., 3 for a width-3 filter).
- I: input channels.
- O: output channels (= features).
Compare with PyTorch’s (O, I, W) layout: Flax/JAX reverses the axis order,
in keeping with its channels-last convention for inputs.
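If you are porting weights from a channels-first framework, the conversion is a single transpose. Here is a minimal sketch; kernel_oiw is a made-up array standing in for a kernel stored in (O, I, W) order, not something from this problem.

```python
import jax.numpy as jnp

# Hypothetical kernel in (O, I, W) order: 2 outputs, 1 input, width 3.
kernel_oiw = jnp.arange(6.0).reshape(2, 1, 3)

# Reversing the axis order gives the (W, I, O) layout used here.
kernel_wio = jnp.transpose(kernel_oiw, (2, 1, 0))

print(kernel_wio.shape)  # (3, 1, 2)
```

The same weight is now addressed as kernel_wio[w, i, o] instead of kernel_oiw[o, i, w].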
dimension_numbers tuple
conv_general_dilated takes a tuple (lhs_spec, rhs_spec, out_spec):
- lhs_spec describes input axes: "NWC" means (batch, width, channel).
- rhs_spec describes kernel axes: "WIO".
- out_spec describes output axes: "NWC" (same as input).
The strings are the layout’s order of axes; JAX uses them to permute as needed under the hood.
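To see that the strings only describe axis order, here is a small sketch that runs the same convolution twice: once channels-last, once channels-first with every array transposed to match. The array values are arbitrary and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

x_nwc = jnp.arange(8.0).reshape(1, 4, 2)    # (batch=1, width=4, channels=2)
k_wio = jnp.arange(18.0).reshape(3, 2, 3)   # (width=3, in=2, out=3)

y_nwc = jax.lax.conv_general_dilated(
    x_nwc, k_wio, window_strides=(1,), padding="SAME",
    dimension_numbers=("NWC", "WIO", "NWC"),
)

# Same data, rearranged into a channels-first layout.
x_ncw = jnp.transpose(x_nwc, (0, 2, 1))     # (1, 2, 4)
k_oiw = jnp.transpose(k_wio, (2, 1, 0))     # (out=3, in=2, width=3)
y_ncw = jax.lax.conv_general_dilated(
    x_ncw, k_oiw, window_strides=(1,), padding="SAME",
    dimension_numbers=("NCW", "OIW", "NCW"),
)

# Transposing the channels-first result back recovers the channels-last one.
same = bool(jnp.allclose(y_nwc, jnp.transpose(y_ncw, (0, 2, 1))))
print(y_nwc.shape, y_ncw.shape, same)
```

Only the layout changes between the two calls; the arithmetic is identical.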
Worked example
import jax
import jax.numpy as jnp
x = jnp.array([1., 2., 3., 4.])[:, None] # shape (4, 1) — len=4, 1 channel
kernel = jnp.ones((3, 1, 2)) # WIO: width=3, in=1, out=2
y = jax.lax.conv_general_dilated(
    x[None, ...],  # add batch dim → (1, 4, 1)
    kernel,
    window_strides=(1,),
    padding="SAME",
    dimension_numbers=("NWC", "WIO", "NWC"),
)
# y shape (1, 4, 2). Drop batch → (4, 2).
With padding="SAME" and stride 1, output length equals input length. With
padding="VALID", length shrinks by kernel_size - 1.
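A quick check of both padding modes on the same input and all-ones kernel as the worked example. The conv1d helper is a convenience wrapper written for this sketch, not part of JAX.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 4.])[None, :, None]   # (1, 4, 1)
kernel = jnp.ones((3, 1, 2))                     # WIO: width=3, in=1, out=2

def conv1d(x, kernel, padding):
    # Helper for this sketch: stride-1 conv in the NWC / WIO layout.
    return jax.lax.conv_general_dilated(
        x, kernel, window_strides=(1,), padding=padding,
        dimension_numbers=("NWC", "WIO", "NWC"),
    )

y_same = conv1d(x, kernel, "SAME")[0]    # shape (4, 2)
y_valid = conv1d(x, kernel, "VALID")[0]  # shape (2, 2)

print(y_same[:, 0])   # [3. 6. 9. 7.]
print(y_valid[:, 0])  # [6. 9.]
```

"SAME" pads one zero on each side, so the edge outputs sum only two real inputs (1+2 and 3+4), while "VALID" keeps only the windows that fit entirely inside the input.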
Why use conv_general_dilated and not jnp.convolve?
jnp.convolve is 1-D, single-channel, and treats the kernel as flipped
(mathematical convolution). What we want is the multi-channel, batched
cross-correlation every neural net uses — that’s conv_general_dilated.
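The flip is easy to see with an asymmetric kernel. A small sketch, with values picked only for illustration:

```python
import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 4.])
k = jnp.array([1., 2., 3.])

# Mathematical convolution: the kernel is flipped before sliding.
conv = jnp.convolve(x, k, mode="valid")

# Cross-correlation: the kernel slides as written.
xcorr = jax.lax.conv_general_dilated(
    x[None, :, None], k[:, None, None],
    window_strides=(1,), padding="VALID",
    dimension_numbers=("NWC", "WIO", "NWC"),
)[0, :, 0]

# Flipping the kernel by hand makes jnp.convolve agree with the conv primitive.
conv_preflipped = jnp.convolve(x, k[::-1], mode="valid")

print(conv)             # [10. 16.]
print(xcorr)            # [14. 20.]
print(conv_preflipped)  # [14. 20.]
```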
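Once you’ve internalized this primitive, Conv2D and Conv3D are a small step: add the extra spatial axes to the input, kernel, and window_strides, and change the dimension_numbers strings (for 2-D, ("NHWC", "HWIO", "NHWC")).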
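For instance, a 2-D version of the same call might look like the sketch below. The shapes are made up for illustration; the only changes from the 1-D case are the extra spatial axis and the layout strings.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 4, 4, 1))        # NHWC: batch=1, 4x4 image, 1 channel
kernel = jnp.ones((3, 3, 1, 2))   # HWIO: 3x3 window, in=1, out=2

y = jax.lax.conv_general_dilated(
    x, kernel,
    window_strides=(1, 1),        # one stride per spatial axis
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

print(y.shape)  # (1, 4, 4, 2)
```

With all-ones input and kernel, interior outputs equal 9 (a full 3x3 window) and corner outputs equal 4 (the rest of the window falls on zero padding).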
Common pitfalls
- Wrong kernel layout: (W, O, I) instead of (W, I, O) raises a shape error when C_in differs from features. When the two are equal the shapes line up, so init and apply both run without erroring, but the outputs are garbage.
- Forgetting the batch dim: conv_general_dilated requires a batched input. Use x[None, ...] and squeeze the result with y[0].
- Missing bias: bias has shape (O,) and broadcasts across L when added.
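The silent version of the layout bug and the bias broadcast can both be reproduced in a few lines. This is a sketch with arbitrary values; conv1d is a helper written for this example, and the case shown is the one where C_in equals features.

```python
import jax
import jax.numpy as jnp

def conv1d(x, kernel):
    # Helper for this sketch: add a batch dim, convolve, drop the batch dim.
    y = jax.lax.conv_general_dilated(
        x[None, ...], kernel, window_strides=(1,), padding="SAME",
        dimension_numbers=("NWC", "WIO", "NWC"),
    )
    return y[0]

x = jnp.arange(8.0).reshape(4, 2)            # L=4, C_in=2
kernel = jnp.arange(12.0).reshape(3, 2, 2)   # (W, I, O) with I == O == 2
swapped = jnp.transpose(kernel, (0, 2, 1))   # (W, O, I), wrongly passed as WIO

y_ok = conv1d(x, kernel)
y_bad = conv1d(x, swapped)

print(y_ok.shape == y_bad.shape)        # True: nothing complains
print(bool(jnp.allclose(y_ok, y_bad)))  # False: the values differ

bias = jnp.array([10., 20.])             # shape (O,)
y_biased = y_ok + bias                   # the same bias is added at every position
```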
Problem
Implement MyConv1D(features, kernel_size):
- Kernel shape (kernel_size, x.shape[-1], features), init lecun_normal().
- Bias shape (features,), init zeros.
- Run conv with window_strides=(1,), padding="SAME", dimension_numbers=("NWC", "WIO", "NWC") after adding a batch dim.
- Add bias and squeeze batch dim.
- Return the output flattened to 1-D for testing (.reshape(-1)).
Inputs:
- seed: float (cast to int).
- x: 2-D JAX array, shape (L, C_in).
- features: int (cast to int).
- kernel_size: int (cast to int).
Output: 1-D array, length L * features.
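The steps above can be sketched as a plain JAX function. This is not the reference solution: the problem asks for a module named MyConv1D, while the function my_conv1d below, and the way it turns seed into a PRNG key, are assumptions made to keep the sketch self-contained.

```python
import jax
import jax.numpy as jnp

def my_conv1d(seed, x, features, kernel_size):
    """Function-style sketch of the steps; names and seeding are assumptions."""
    features, kernel_size = int(features), int(kernel_size)
    key = jax.random.PRNGKey(int(seed))  # assumption: seed feeds a PRNGKey directly

    # Kernel in WIO layout, LeCun-normal init; bias starts at zero.
    kernel = jax.nn.initializers.lecun_normal()(
        key, (kernel_size, x.shape[-1], features), x.dtype
    )
    bias = jnp.zeros((features,), x.dtype)

    y = jax.lax.conv_general_dilated(
        x[None, ...],  # add batch dim
        kernel,
        window_strides=(1,),
        padding="SAME",
        dimension_numbers=("NWC", "WIO", "NWC"),
    )
    # Drop the batch dim, add bias (broadcast over L), flatten to 1-D.
    return (y[0] + bias).reshape(-1)

out = my_conv1d(0.0, jnp.ones((5, 2)), 3, 3)
print(out.shape)  # (15,)
```

The random values will not necessarily match a Flax module initialized with the same seed, since how a module derives its parameter keys is not specified here.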