
NNX Implement Conv1D

Why this matters

1-D convolutions show up in audio, time-series, and character-level CNNs, and they're a stepping stone to 2-D conv. Reimplementing Conv1D in nnx exercises the same primitive that Linen's nn.Conv is built on (jax.lax.conv_general_dilated), but with nnx's attribute-based parameter pattern. The math is the same as in any other framework; only the surrounding ceremony changes.

In Linen you carry a separate params dict and pass it to apply. In nnx the kernel and bias are plain attributes of the module, exactly as in nnx.Linear (the nnx counterpart of Linen's Dense), just with conv-shaped tensors and a different __call__.

Input layout

Flax convention is channels-LAST:

  • Input x: shape (L, C_in) — sequence length × channels.
  • Output y: shape (L_out, C_out).

conv_general_dilated requires a batched input, so we add a batch dim with x[None, ...], run the conv, then drop the batch with y[0].
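A small sketch of the add-then-drop pattern on an unbatched (L, C_in) input. The all-ones kernel is an illustrative assumption, chosen only so the shapes are easy to follow.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(6, 2)   # (L=6, C_in=2), unbatched
kernel = jnp.ones((3, 2, 5))         # (W=3, I=2, O=5), illustrative all-ones kernel

x_b = x[None, ...]                   # add batch axis N -> (1, 6, 2)
y_b = jax.lax.conv_general_dilated(
    x_b,
    kernel,
    window_strides=(1,),
    padding="SAME",
    dimension_numbers=("NWC", "WIO", "NWC"),
)                                    # batched output (1, 6, 5)
y = y_b[0]                           # drop batch axis -> (6, 5)
```

x[None, ...] is equivalent to jnp.expand_dims(x, 0), and y[0] to jnp.squeeze(y, 0); either spelling works.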

Kernel layout: WIO

The 1-D conv kernel has shape (W, I, O):

  • W: kernel width.
  • I: input channels (C_in).
  • O: output channels (C_out = features).

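PyTorch's nn.Conv1d stores its weight as (O, I, W). Flax/JAX puts the spatial axis first and output channels last, following the TensorFlow-style layout, so weights ported from PyTorch need a transpose.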
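A minimal sketch of porting a PyTorch Conv1d weight into the WIO layout. To stay self-contained it assumes the weight has already been exported to a NumPy array; the random torch_weight below is a stand-in, not a real checkpoint.

```python
import numpy as np
import jax.numpy as jnp

W, C_in, C_out = 3, 4, 8
# Stand-in for an exported PyTorch Conv1d weight, layout (O, I, W).
torch_weight = np.random.default_rng(0).normal(size=(C_out, C_in, W))

# Reorder axes (O, I, W) -> (W, I, O) for conv_general_dilated with rhs_spec "WIO".
flax_kernel = jnp.asarray(torch_weight.transpose(2, 1, 0))
```

Both frameworks compute cross-correlation, so a pure axis permutation is enough: no spatial flip of the kernel is needed.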

dimension_numbers

dimension_numbers=(lhs_spec, rhs_spec, out_spec) tells conv_general_dilated how to interpret the axes of each array:

  • lhs_spec = "NWC" — input axes are batch, width, channel.
  • rhs_spec = "WIO" — kernel axes are width, in, out.
  • out_spec = "NWC" — output axes match input.
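To see what the three specs mean concretely, this sketch checks conv_general_dilated against a hand-written cross-correlation, y[n, t, o] = Σ_w Σ_i x[n, t+w, i] · k[w, i, o]. It uses padding="VALID" (an assumption for this check only) so no padding arithmetic is involved.

```python
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
L, W, C_in, C_out = 7, 3, 2, 4
x = rng.normal(size=(1, L, C_in)).astype(np.float32)       # "NWC": batch, width, channel
k = rng.normal(size=(W, C_in, C_out)).astype(np.float32)   # "WIO": width, in, out

y = jax.lax.conv_general_dilated(
    jnp.asarray(x),
    jnp.asarray(k),
    window_strides=(1,),
    padding="VALID",
    dimension_numbers=("NWC", "WIO", "NWC"),
)  # (1, L - W + 1, C_out)

# Reference: slide a width-W window and contract over (w, i).
ref = np.zeros((1, L - W + 1, C_out), dtype=np.float32)
for t in range(L - W + 1):
    ref[0, t] = np.einsum("wi,wio->o", x[0, t:t + W], k)
```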

Worked sketch

import jax
import jax.numpy as jnp
from flax import nnx


class MyConv1D(nnx.Module):
    def __init__(self, in_features, out_features, kernel_size, *, rngs: nnx.Rngs):
        key = rngs.params()                              # fresh key from the "params" stream
        init = jax.nn.initializers.lecun_normal()
        # Kernel layout (W, I, O) to match the "WIO" rhs_spec below.
        self.kernel = nnx.Param(
            init(key, (kernel_size, in_features, out_features))
        )
        self.bias = nnx.Param(jnp.zeros((out_features,)))

    def __call__(self, x):
        x_b = x[jnp.newaxis, ...]                        # (1, L, C_in)
        y = jax.lax.conv_general_dilated(
            x_b,
            self.kernel.value,                           # unwrap the Param to a jax.Array
            window_strides=(1,),
            padding="SAME",                              # L_out == L at stride 1
            dimension_numbers=("NWC", "WIO", "NWC"),
        )                                                # (1, L, C_out)
        return y[0] + self.bias.value                    # (L, C_out)

Two things worth noting:

  • jax.nn.initializers.lecun_normal() returns a callable; call it with (key, shape) to materialize the initial values. It's the same initializer Linen's nn.Conv and nn.Dense use by default.
  • self.kernel.value explicitly unwraps the Param into the underlying jax.Array. nnx.Param overloads the arithmetic dunders, so bare self.kernel may work in plain math, but jax.lax primitives such as conv_general_dilated expect an array, so pass .value. The sketch unwraps self.bias the same way for consistency.
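A sketch of what lecun_normal actually produces for a conv kernel. It targets variance 1/fan_in, and for a (W, I, O) kernel JAX computes fan_in as the receptive-field size times the input channels, W · I. The shape below is an arbitrary assumption, large enough for the sample std to be stable.

```python
import jax
import jax.numpy as jnp

init = jax.nn.initializers.lecun_normal()        # a callable: init(key, shape, dtype=...)
kernel = init(jax.random.key(0), (3, 64, 256))   # (W, I, O)

# fan_in for a (W, I, O) kernel is W * I.
fan_in = 3 * 64
expected_std = (1.0 / fan_in) ** 0.5
sample_std = float(jnp.std(kernel))
```

The distribution is a truncated normal rescaled so its variance still equals 1/fan_in, which is why the sample std lands near sqrt(1/fan_in).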

Why nnx makes Conv simpler

The math is identical to Linen, but you skip the init/apply round-trip: no model.init(key, x) to allocate the weights and no model.apply(params, x) to use them. The kernel and bias exist as soon as the module is constructed, and model(x) runs the forward pass.

Common pitfalls

  • Wrong kernel layout. Storing (W, O, I) instead of (W, I, O) fails with a shape error when C_in != C_out, but when C_in == C_out it runs without crashing and silently computes the wrong function.
  • Forgetting the batch dim. conv_general_dilated requires it. Add with x[None, ...], drop with y[0].
  • Bare self.kernel into conv_general_dilated. Use self.kernel.value to pass the underlying JAX array.
  • Forgetting to cast features/kernel_size. In this problem numeric inputs arrive as floats; cast them to int in __init__ (or before passing them in), since array shapes must be integers.

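A sketch reproducing the kernel-layout pitfall. The helper conv is hypothetical, a bias-free copy of the worked sketch's call. JAX's exact exception type is not pinned down here, so both ValueError and TypeError are caught.

```python
import jax
import jax.numpy as jnp


def conv(x, kernel):
    # Hypothetical helper: the worked sketch's conv call, without bias.
    return jax.lax.conv_general_dilated(
        x[None, ...], kernel, window_strides=(1,), padding="SAME",
        dimension_numbers=("NWC", "WIO", "NWC"))[0]


x = jax.random.normal(jax.random.key(0), (8, 4))          # (L, C_in=4)
good = jax.random.normal(jax.random.key(1), (3, 4, 6))    # (W, I, O)

# C_in != C_out: a (W, O, I) kernel has the wrong in-channel size, so JAX raises.
bad = jnp.transpose(good, (0, 2, 1))                       # (3, 6, 4)
try:
    conv(x, bad)
    raised = False
except (ValueError, TypeError):
    raised = True

# C_in == C_out: the swapped kernel type-checks but computes a different function.
square = jax.random.normal(jax.random.key(2), (3, 4, 4))
swapped = jnp.transpose(square, (0, 2, 1))
silently_different = not bool(jnp.allclose(conv(x, square), conv(x, swapped)))
```

The square case is the dangerous one: nothing fails, so only a shape-aware test or a reference comparison catches it.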
Problem

Write conv1d_forward(seed, x, features, kernel_size):

  1. Define MyConv1D(nnx.Module) with two nnx.Params:
    • self.kernel shape (kernel_size, in_features, out_features), init with jax.nn.initializers.lecun_normal()(key, shape).
    • self.bias shape (out_features,), zeros.
  2. __call__: add batch dim, call conv_general_dilated with window_strides=(1,), padding="SAME", dimension_numbers=("NWC", "WIO", "NWC"). Drop batch dim, add bias.
  3. Build rngs = nnx.Rngs(int(seed)), instantiate MyConv1D(in_features=x.shape[-1], out_features=int(features), kernel_size=int(kernel_size), rngs=rngs), and return model(x).reshape(-1).

Inputs:

  • seed: int (passed as float — cast to int).
  • x: 2-D JAX array shape (L, C_in).
  • features, kernel_size: ints (passed as floats — cast).

Output: 1-D array of length L * features (the (L, features) conv output, flattened; "SAME" padding at stride 1 keeps L_out = L).

Hints

flax nnx conv1d reimplementation
