
NNX Implement Conv2D

Why this matters

Conv2D is the workhorse of vision. ResNets, U-Nets, and EfficientNets are built from stacks of Conv2D → norm → activation blocks, often with strided convolutions for downsampling, and a ViT's patch embedding is itself a single Conv2D whose stride equals its kernel size. Reimplementing Conv2D in nnx is the same exercise you’d do in Linen, with one fewer layer of ceremony: parameters live on the module, and calling the module runs the forward pass.

The math is identical to Conv1D — just one more spatial axis. If the last problem made sense, this one should feel mechanical.

Input layout

For 2-D convolution, the Flax convention is channels-last:

  • Input x: (H, W, C_in) — height, width, in-channels.
  • Output y: (H_out, W_out, C_out).

Add a batch dim → (1, H, W, C_in) for conv_general_dilated.

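PyTorch is NCHW (channels-first). Flax’s conv layers use NHWC, largely because channels-last tends to be the efficient layout for XLA on TPUs and modern GPUs. PyTorch’s “channels-last” memory format is exactly this layout, but PyTorch defaults to NCHW and may reorder data under the hood for some GPU kernels; Flax simply standardizes on channels-last and sidesteps those conversions. One catch: JAX’s low-level lax.conv_general_dilated assumes NCHW inputs and OIHW kernels when dimension_numbers is omitted, so a hand-written layer should always pass dimension_numbers explicitly.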
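A minimal sketch of getting data into the layout described above, assuming a hypothetical channels-first batch such as one produced by a PyTorch-style data pipeline. jnp.transpose moves the channel axis to the end, and a single channels-last image only needs a leading batch axis:

```python
import jax.numpy as jnp

# Hypothetical channels-first batch (N, C, H, W), e.g. from a PyTorch-style loader.
x_nchw = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape(2, 3, 4, 5)

# NCHW -> NHWC: the channel axis (1) moves to the end.
x_nhwc = jnp.transpose(x_nchw, (0, 2, 3, 1))   # (N, H, W, C)

# Element [n, c, h, w] in NCHW becomes element [n, h, w, c] in NHWC.
print(x_nhwc.shape, x_nchw[1, 2, 3, 4] == x_nhwc[1, 3, 4, 2])

# A single channels-last image (H, W, C_in) just gets a leading batch axis.
img = jnp.zeros((4, 5, 3))
img_b = img[jnp.newaxis, ...]                  # (1, H, W, C_in)
print(img_b.shape, img_b[0].shape)             # [0] drops the batch axis again
```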

Kernel layout: HWIO

The 2-D conv kernel has shape (kh, kw, C_in, C_out):

  • kh, kw: kernel height and width.
  • C_in: input channels.
  • C_out: output channels.
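A sketch of converting a PyTorch-style Conv2d weight, stored as (C_out, C_in, kh, kw) (OIHW), into the HWIO layout above. Both PyTorch and lax.conv_general_dilated compute cross-correlation without flipping the kernel, so a transpose is all that is needed. The check runs the same convolution channels-first and compares; the shapes are arbitrary:

```python
import jax
import jax.numpy as jnp

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x_nhwc = jax.random.normal(key_x, (1, 5, 5, 4))   # (N, H, W, C_in)
w_oihw = jax.random.normal(key_w, (8, 4, 3, 3))   # PyTorch-style (C_out, C_in, kh, kw)

# OIHW -> HWIO: new axes (H, W, I, O) come from old axes (2, 3, 1, 0).
w_hwio = jnp.transpose(w_oihw, (2, 3, 1, 0))      # (kh, kw, C_in, C_out)

y_hwio = jax.lax.conv_general_dilated(
    x_nhwc, w_hwio, (1, 1), "SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))

# The same convolution expressed channels-first, then moved back to NHWC.
x_nchw = jnp.transpose(x_nhwc, (0, 3, 1, 2))
y_nchw = jax.lax.conv_general_dilated(
    x_nchw, w_oihw, (1, 1), "SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"))
y_back = jnp.transpose(y_nchw, (0, 2, 3, 1))

print(w_hwio.shape, y_hwio.shape, bool(jnp.allclose(y_hwio, y_back, atol=1e-3)))
```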

Stride and SAME padding

With padding="SAME" and stride=s:

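  • Output spatial size = ceil(H / s) (and likewise for W).
  • JAX pads each spatial axis with just enough zeros for the last window to fit: total padding = max((H_out - 1) * s + kh - H, 0), split as evenly as possible, with any odd extra pixel going to the bottom/right.

With stride=1, the output has the same spatial size as the input. With stride=2, it is roughly half (ceil(H/2)).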

Examples for input H=3, W=3:

  • kernel 3x3, stride=1, SAME → output (3, 3, ...).
  • kernel 3x3, stride=2, SAME → output (2, 2, ...).
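To make the SAME rule concrete, here is a sketch that computes the padding by hand with a hypothetical helper, same_pads, and checks that explicit padding reproduces padding="SAME". It covers the two examples above plus an even-sized input (H=W=4, stride 2), where the padding becomes asymmetric:

```python
import math
import jax
import jax.numpy as jnp

DN = ("NHWC", "HWIO", "NHWC")

def same_pads(size, k, s):
    """Hypothetical helper: (low, high) SAME padding for one spatial axis."""
    out = math.ceil(size / s)                 # SAME output size
    total = max((out - 1) * s + k - size, 0)  # just enough for the last window
    return (total // 2, total - total // 2)   # odd extra pixel goes on the high side

key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
kernel = jax.random.normal(key_k, (3, 3, 2, 4))  # HWIO: 3x3 kernel, 2 -> 4 channels

results = []
for size, s in [(3, 1), (3, 2), (4, 2)]:
    x = jax.random.normal(key_x, (1, size, size, 2))  # NHWC, batch of 1
    y_same = jax.lax.conv_general_dilated(x, kernel, (s, s), "SAME", dimension_numbers=DN)
    pads = [same_pads(size, 3, s)] * 2                # same rule for H and W
    y_pad = jax.lax.conv_general_dilated(x, kernel, (s, s), pads, dimension_numbers=DN)
    match = bool(jnp.allclose(y_same, y_pad, atol=1e-5))
    results.append((size, s, pads[0], y_same.shape, match))
    print(size, s, pads[0], y_same.shape, match)
```

For H=4 with stride 2 the total padding is 1, so the single extra row and column land on the bottom/right.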

Worked sketch

import jax
import jax.numpy as jnp
from flax import nnx

class MyConv2D(nnx.Module):
    def __init__(self, in_features, out_features, kernel_h, kernel_w,
                 stride, rngs):
        key = rngs.params()                        # fresh PRNG key for parameter init
        init = jax.nn.initializers.lecun_normal()
        self.kernel = nnx.Param(
            init(key, (kernel_h, kernel_w, in_features, out_features))  # HWIO
        )
        self.bias = nnx.Param(jnp.zeros((out_features,)))
        self.stride = stride                      # plain attribute, not a Param

    def __call__(self, x):
        x_b = x[jnp.newaxis, ...]                  # (1, H, W, C_in)
        y = jax.lax.conv_general_dilated(
            x_b,
            self.kernel.value,                     # unwrap the Param to a raw array
            window_strides=(self.stride, self.stride),
            padding="SAME",
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
        )
        return y[0] + self.bias.value              # (H_out, W_out, C_out)

self.stride = stride (without nnx.Param or nnx.Variable) is a static attribute — not part of the module’s state pytree, but perfectly fine for fixed hyperparameters like stride or padding mode.

Why nnx makes this clean

Linen’s nn.Conv and nnx’s nnx.Conv are both thin wrappers around lax.conv_general_dilated that add conveniences such as batch-dimension handling. When you write your own conv layer in nnx, you call conv_general_dilated directly, and the rest of the module is just attribute assignments: no @nn.compact and no self.param (those are Linen idioms).

Common pitfalls

  • Using PyTorch’s (C_out, C_in, kh, kw) (OIHW) kernel layout. Flax expects (kh, kw, C_in, C_out) (HWIO).
  • Forgetting x[None, ...]. conv_general_dilated requires a batched input, and the error it raises otherwise (a rank mismatch) doesn’t mention batching.
  • Off-by-one with stride>1. With SAME padding, the output spatial size is ceil(H/stride), not H // stride.
  • Passing the bare self.kernel into conv_general_dilated. Don’t rely on the primitive to unwrap an nnx.Param; pass self.kernel.value.
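The batch-axis and stride pitfalls can be reproduced in a few lines. This sketch catches the failure generically, since the exact exception type and message may differ across JAX versions:

```python
import jax
import jax.numpy as jnp

DN = ("NHWC", "HWIO", "NHWC")
x = jnp.ones((5, 5, 2))          # unbatched (H, W, C_in)
kernel = jnp.ones((3, 3, 2, 4))  # HWIO

# Pitfall: feeding the rank-3 input straight to the primitive fails.
try:
    jax.lax.conv_general_dilated(x, kernel, (2, 2), "SAME", dimension_numbers=DN)
    unbatched_ok = True
except (TypeError, ValueError) as err:  # exact type/message may vary by JAX version
    unbatched_ok = False
    print(type(err).__name__, err)

# Fix: add the batch axis for the call, then drop it again.
y = jax.lax.conv_general_dilated(
    x[None, ...], kernel, (2, 2), "SAME", dimension_numbers=DN)[0]

# Stride 2 on H = W = 5 gives ceil(5 / 2) = 3 rows and columns, not 5 // 2 = 2.
print(y.shape)
```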

Problem

Write conv2d_forward(seed, x, features, kernel_h, kernel_w, stride):

  1. Define MyConv2D(nnx.Module) with a kernel nnx.Param of shape (kernel_h, kernel_w, in_features, out_features) (HWIO, lecun-normal init), a bias nnx.Param of shape (out_features,) (zeros), and self.stride = stride as a static attribute.
  2. In __call__: add the batch dim, run conv_general_dilated with window_strides=(stride, stride), padding="SAME", and dimension_numbers=("NHWC", "HWIO", "NHWC"), then drop the batch dim and add the bias.
  3. Build the module with nnx.Rngs(int(seed)), instantiating it with in_features = x.shape[-1] and all other dims cast to int, and return model(x).reshape(-1).

Inputs:

  • seed: int (passed as float).
  • x: 3-D JAX array (H, W, C_in).
  • features, kernel_h, kernel_w, stride: ints (passed as floats).

Output: 1-D array (flattened conv output).
