NNX Implement Conv1D
Why this matters
1-D convolutions show up in audio, time-series, and character-CNNs —
and they’re a stepping stone to 2-D conv. Reimplementing Conv1D in nnx
pins down the same primitive (jax.lax.conv_general_dilated) you’d use
in Linen, but with the modern attribute-based parameter pattern. The
math is unchanged from any other framework; what changes is the
surrounding ceremony.
In Linen you’d carry a separate params dict and call apply with it.
In nnx the kernel and bias are just attributes of the module — same
as Dense, just with conv-shaped tensors and a different __call__.
Input layout
Flax convention is channels-LAST:
- Input x: shape (L, C_in), sequence length × channels.
- Output y: shape (L_out, C_out).
conv_general_dilated requires a batched input, so we add a batch dim
with x[None, ...], run the conv, then drop the batch with y[0].
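The batch round trip can be seen in isolation with a small sketch. The array sizes and the all-ones kernel here are arbitrary choices for illustration, not values from the problem.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only: L=6, C_in=2, W=3, C_out=4.
x = jnp.arange(12.0).reshape(6, 2)   # (L, C_in), no batch axis
kernel = jnp.ones((3, 2, 4))         # (W, I, O)

x_b = x[None, ...]                   # add batch axis -> (1, L, C_in)
y_b = jax.lax.conv_general_dilated(
    x_b,
    kernel,
    window_strides=(1,),
    padding="SAME",
    dimension_numbers=("NWC", "WIO", "NWC"),
)
y = y_b[0]                           # drop batch axis -> (L, C_out)

print(x_b.shape, y_b.shape, y.shape)
```

With stride 1 and "SAME" padding the output length equals the input length, so only the channel axis changes size.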
Kernel layout: WIO
The 1-D conv kernel has shape (W, I, O):
- W: kernel width.
- I: input channels (C_in).
- O: output channels (C_out = features).
PyTorch stores the same kernel as (O, I, W); Flax/JAX reverses the axis order, keeping channels last to match the input layout.
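If a kernel arrives in the PyTorch-style (O, I, W) order, a transpose brings it to WIO. The sketch below uses a made-up kernel and also checks that telling conv_general_dilated the kernel is "OIW" gives the same result as transposing first. Treat it as an illustration of the layout relationship rather than a recommended workflow.

```python
import jax
import jax.numpy as jnp

# Hypothetical kernel in PyTorch-style order: (O=4, I=2, W=3).
k_oiw = jnp.arange(24.0).reshape(4, 2, 3)

# Reverse the axes to get the Flax/JAX order: (W=3, I=2, O=4).
k_wio = jnp.transpose(k_oiw, (2, 1, 0))

x_b = jnp.ones((1, 6, 2))  # (N, W, C)

# Same convolution, once with the transposed kernel ...
y_wio = jax.lax.conv_general_dilated(
    x_b, k_wio, window_strides=(1,), padding="SAME",
    dimension_numbers=("NWC", "WIO", "NWC"),
)
# ... and once with the original kernel, described by its own spec.
y_oiw = jax.lax.conv_general_dilated(
    x_b, k_oiw, window_strides=(1,), padding="SAME",
    dimension_numbers=("NWC", "OIW", "NWC"),
)

print(k_wio.shape, bool(jnp.allclose(y_wio, y_oiw)))
```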
dimension_numbers
dimension_numbers is a 3-tuple (lhs_spec, rhs_spec, out_spec) passed to conv_general_dilated:
- lhs_spec = "NWC": input axes are batch, width, channel.
- rhs_spec = "WIO": kernel axes are width, in, out.
- out_spec = "NWC": output axes match input.
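One way to confirm what these specs mean is to compare the JAX result against a hand-written sliding window in NumPy. The sketch below assumes stride 1 and "SAME" padding, for which the total padding of W - 1 is split with the smaller half on the left. The random inputs and sizes are arbitrary.

```python
import numpy as np
import jax
import jax.numpy as jnp

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2)).astype(np.float32)     # (L, C_in)
k = rng.normal(size=(3, 2, 4)).astype(np.float32)  # (W, I, O)

y_lax = jax.lax.conv_general_dilated(
    jnp.asarray(x)[None, ...],
    jnp.asarray(k),
    window_strides=(1,),
    padding="SAME",
    dimension_numbers=("NWC", "WIO", "NWC"),
)[0]

# Reference: zero-pad along the width axis, then slide the window.
# The kernel is not flipped, i.e. this is a cross-correlation.
W = k.shape[0]
pad_lo = (W - 1) // 2
pad_hi = (W - 1) - pad_lo
x_pad = np.pad(x, ((pad_lo, pad_hi), (0, 0)))
y_ref = np.stack(
    [np.einsum("wi,wio->o", x_pad[t:t + W], k) for t in range(x.shape[0])]
)

print(np.allclose(np.asarray(y_lax), y_ref, atol=1e-4))
```

Each output position is the window of width W contracted against the kernel over both the width and input-channel axes, which is exactly what the "WIO" spec describes.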
Worked sketch
import jax
import jax.numpy as jnp
from flax import nnx

class MyConv1D(nnx.Module):
    def __init__(self, in_features, out_features, kernel_size, rngs):
        key = rngs.params()
        init = jax.nn.initializers.lecun_normal()
        # Kernel layout is WIO: (width, in channels, out channels).
        self.kernel = nnx.Param(
            init(key, (kernel_size, in_features, out_features))
        )
        self.bias = nnx.Param(jnp.zeros((out_features,)))

    def __call__(self, x):
        x_b = x[jnp.newaxis, ...]  # (1, L, C_in)
        y = jax.lax.conv_general_dilated(
            x_b,
            self.kernel.value,
            window_strides=(1,),
            padding="SAME",
            dimension_numbers=("NWC", "WIO", "NWC"),
        )
        return y[0] + self.bias  # (L, C_out)
Two things worth noting:
- jax.nn.initializers.lecun_normal() returns a callable; call it with (key, shape) to materialize the initial values. It’s the same initializer Linen’s nn.Conv and nn.Dense use by default.
- self.kernel.value explicitly unwraps the Param. Bare self.kernel also works in math ops (because nnx.Param defines __array__ / arithmetic dunders), but conv_general_dilated is picky, so pass .value.
Why nnx makes Conv simpler
The math is identical to Linen, but you skip the init/apply round-trip:
no model.init(key, x) to allocate weights, no model.apply(params, x)
to invoke them. As soon as the module is instantiated, the kernel
and bias exist; model(x) runs the forward pass.
Common pitfalls
- Wrong kernel layout. (W, O, I) instead of (W, I, O) runs without crashing when C_in == C_out but produces garbage; with unequal channel counts JAX rejects the shapes.
- Forgetting the batch dim. conv_general_dilated requires it. Add with x[None, ...], drop with y[0].
- Bare self.kernel into conv_general_dilated. Use self.kernel.value to pass the underlying JAX array.
- Casting features / kernel_size. Numeric inputs arrive as floats; cast to int in __init__ (or before passing in).
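The kernel-layout pitfall can be reproduced in a few lines. In this sketch conv1d is a hypothetical helper wrapping the same conv_general_dilated call, and the kernels are made-up values. It shows a swapped layout being rejected when the channel counts differ and passing silently when they match.

```python
import jax
import jax.numpy as jnp


def conv1d(x, kernel):
    # Hypothetical helper: unbatched (L, C_in) input, WIO kernel.
    return jax.lax.conv_general_dilated(
        x[None, ...], kernel,
        window_strides=(1,), padding="SAME",
        dimension_numbers=("NWC", "WIO", "NWC"),
    )[0]


x = jnp.ones((6, 2))              # (L, C_in=2)

# C_in != C_out: a (W, O, I) kernel has the wrong "I" size and is rejected.
bad = jnp.ones((3, 4, 2))         # (W, O, I) by mistake
try:
    conv1d(x, bad)
    mismatch_caught = False
except Exception:
    mismatch_caught = True

# C_in == C_out: the swapped kernel has a valid shape, so nothing fails.
k = jnp.arange(12.0).reshape(3, 2, 2)   # intended (W, I, O)
k_swapped = jnp.swapaxes(k, 1, 2)       # accidentally (W, O, I)
y_right = conv1d(x, k)
y_wrong = conv1d(x, k_swapped)

print(mismatch_caught, y_wrong.shape, bool(jnp.allclose(y_right, y_wrong)))
```

A shape assertion on the kernel in __init__ is a cheap guard, though it cannot catch the square case. Only a comparison against a known-good output does.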
Problem
Write conv1d_forward(seed, x, features, kernel_size):
- Define MyConv1D(nnx.Module) with two nnx.Params:
  - self.kernel shape (kernel_size, in_features, out_features), init with jax.nn.initializers.lecun_normal()(key, shape).
  - self.bias shape (out_features,), zeros.
- __call__: add batch dim, call conv_general_dilated with window_strides=(1,), padding="SAME", dimension_numbers=("NWC", "WIO", "NWC"). Drop batch dim, add bias.
- Build with nnx.Rngs(int(seed)), instantiate (in_features=x.shape[-1], out_features=int(features), kernel_size=int(kernel_size)), return model(x).reshape(-1).
Inputs:
- seed: int (passed as float; cast to int).
- x: 2-D JAX array, shape (L, C_in).
- features, kernel_size: ints (passed as floats; cast).
Output: 1-D array (flattened conv output).