NNX Implement Conv2D
Why this matters
Conv2D is the workhorse of vision. ResNets, U-Nets, and EfficientNets are
largely stacks of Conv2D → norm → activation with strides for
downsampling, and even a ViT patch embedding is a single strided Conv2D.
Reimplementing it in nnx is the same exercise you’d do in Linen, with one
fewer layer of ceremony: parameters live on the module, calling runs the
forward.
The math is identical to Conv1D — just one more spatial axis. If the last problem made sense, this one should feel mechanical.
Input layout
For 2-D conv, the Flax convention is channels-last:
- Input x: (H, W, C_in), i.e. height, width, in-channels.
- Output y: (H_out, W_out, C_out).
Add a batch dim → (1, H, W, C_in) for conv_general_dilated.
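PyTorch is NCHW (channels first) by default. Flax uses NHWC, a layout
that XLA handles well on TPU and modern GPUs. PyTorch’s “channels-last”
memory format is exactly this layout, but it is opt-in: PyTorch defaults
to NCHW and may reorder under the hood for many GPU kernels, while Flax
just decided “channels-last, always” and avoided that class of
performance gotchas. Note that jax.lax.conv_general_dilated itself
defaults to NCHW/OIHW when dimension_numbers is omitted, so the layout
has to be passed explicitly as ("NHWC", "HWIO", "NHWC").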
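As a small illustration of the two layouts, the sketch below moves a PyTorch-style NCHW batch into the NHWC layout described above. The array name and sizes are made up for the example.

```python
import jax.numpy as jnp

# Hypothetical PyTorch-style batch: (N, C, H, W) = (2, 3, 5, 7).
x_nchw = jnp.arange(2 * 3 * 5 * 7, dtype=jnp.float32).reshape(2, 3, 5, 7)

# Move channels to the last axis: (N, C, H, W) -> (N, H, W, C).
x_nhwc = jnp.transpose(x_nchw, (0, 2, 3, 1))

print(x_nhwc.shape)  # (2, 5, 7, 3)
```

The transpose only reorders axes; element (n, c, h, w) of the original ends up at (n, h, w, c).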
Kernel layout: HWIO
The 2-D conv kernel has shape (kh, kw, C_in, C_out):
- kh, kw: kernel height and width.
- C_in: input channels.
- C_out: output channels.
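To make the HWIO ordering concrete, here is a minimal sketch that reorders a kernel stored in PyTorch's (C_out, C_in, kh, kw) order into HWIO and checks that conv_general_dilated accepts it. The kernel values and sizes are placeholders chosen for the example.

```python
import jax
import jax.numpy as jnp

# Hypothetical PyTorch-style kernel: (C_out, C_in, kh, kw) = (8, 3, 5, 3).
k_oihw = jnp.ones((8, 3, 5, 3), dtype=jnp.float32)

# Reorder to HWIO: (kh, kw, C_in, C_out).
k_hwio = jnp.transpose(k_oihw, (2, 3, 1, 0))

# One NHWC image whose channel count (3) matches the kernel's C_in axis.
x = jnp.ones((1, 16, 16, 3), dtype=jnp.float32)
y = jax.lax.conv_general_dilated(
    x,
    k_hwio,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

print(k_hwio.shape, y.shape)  # (5, 3, 3, 8) (1, 16, 16, 8)
```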
Stride and SAME padding
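With padding="SAME" and stride=s:
- Output spatial size = ceil(H / s) (and likewise for W).
- JAX picks the padding so the output is exactly that size: the total padding along an axis is max((H_out - 1) * s + kh - H, 0), with any odd leftover going on the bottom/right.
With stride=1, the output is the same size as the input. With stride=2,
the output is roughly half (ceil(H/2)).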
Examples for input H=3, W=3:
- kernel 3x3, stride=1, SAME → output (3, 3, ...).
- kernel 3x3, stride=2, SAME → output (2, 2, ...).
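The SAME arithmetic can be checked with a short sketch. The helper names same_padding_1d and conv_out_shape are invented for this example; the padding split (extra padding on the high side) reflects how JAX's SAME padding is understood to behave, and the second helper simply asks conv_general_dilated for the shape it actually produces.

```python
import math

import jax
import jax.numpy as jnp


def same_padding_1d(size, kernel, stride):
    """Return (out_size, (pad_lo, pad_hi)) for SAME padding along one axis."""
    out_size = math.ceil(size / stride)
    total = max((out_size - 1) * stride + kernel - size, 0)
    pad_lo = total // 2
    return out_size, (pad_lo, total - pad_lo)


def conv_out_shape(h, w, kh, kw, stride):
    """Spatial output shape reported by conv_general_dilated itself."""
    x = jnp.zeros((1, h, w, 1))
    k = jnp.zeros((kh, kw, 1, 1))
    y = jax.lax.conv_general_dilated(
        x,
        k,
        window_strides=(stride, stride),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return tuple(y.shape[1:3])


# The two examples from the text: H = W = 3, kernel 3x3, stride 1 and 2.
for stride in (1, 2):
    out, pads = same_padding_1d(3, 3, stride)
    print(stride, out, pads, conv_out_shape(3, 3, 3, 3, stride))
```

For the 3x3 input this reports a spatial size of 3 at stride 1 and 2 at stride 2, matching the list above.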
Worked sketch
import jax
import jax.numpy as jnp
from flax import nnx

class MyConv2D(nnx.Module):
    def __init__(self, in_features, out_features, kernel_h, kernel_w,
                 stride, rngs):
        key = rngs.params()
        init = jax.nn.initializers.lecun_normal()
        # HWIO kernel: (kh, kw, C_in, C_out)
        self.kernel = nnx.Param(
            init(key, (kernel_h, kernel_w, in_features, out_features))
        )
        self.bias = nnx.Param(jnp.zeros((out_features,)))
        self.stride = stride  # plain attribute, not a Param

    def __call__(self, x):
        x_b = x[jnp.newaxis, ...]  # (1, H, W, C_in)
        y = jax.lax.conv_general_dilated(
            x_b,
            self.kernel.value,
            window_strides=(self.stride, self.stride),
            padding="SAME",
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
        )
        return y[0] + self.bias  # (H_out, W_out, C_out)
self.stride = stride (without nnx.Param or nnx.Variable) is a
static attribute — not part of the module’s state pytree, but
perfectly fine for fixed hyperparameters like stride or padding mode.
Why nnx makes this clean
Linen’s nn.Conv is a wrapper around conv_general_dilated. nnx’s
nnx.Conv is the same wrapper. When you write your own conv layer in
nnx, you reach for conv_general_dilated directly — and the rest of
the module is just attribute assignments. No @nn.compact, no
self.param (those are Linen idioms).
Common pitfalls
- Using the PyTorch kernel layout (C_out, C_in, kh, kw) (OIHW). Flax is (kh, kw, C_in, C_out) (HWIO).
- Forgetting x[None, ...]. conv_general_dilated requires a batched input. The error is cryptic.
- Off-by-one with stride>1. Output spatial size is ceil(H/stride), not H // stride.
- Passing bare self.kernel into conv_general_dilated. The conv primitive doesn’t auto-unwrap; pass self.kernel.value.
Problem
Write conv2d_forward(seed, x, features, kernel_h, kernel_w, stride):
- Define MyConv2D(nnx.Module) with kernel nnx.Param of shape (kernel_h, kernel_w, in_features, out_features) (HWIO, lecun-normal), bias nnx.Param of shape (out_features,) (zeros), and self.stride = stride as a static attribute.
- __call__: add batch dim, run conv_general_dilated with window_strides=(stride, stride), padding="SAME", dimension_numbers=("NHWC", "HWIO", "NHWC"). Drop batch, add bias.
- Build with nnx.Rngs(int(seed)), instantiate with all dims cast to int, return model(x).reshape(-1).
Inputs:
- seed: int (passed as float).
- x: 3-D JAX array (H, W, C_in).
- features, kernel_h, kernel_w, stride: ints (passed as floats).
Output: 1-D array (flattened conv output).
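A quick sanity check for a solution is the length of the flattened output, which follows from the SAME rule: ceil(H / stride) * ceil(W / stride) * features. The helper below is a hypothetical checker, not part of the required interface; it casts the float arguments to int as the problem describes.

```python
import math


def expected_flat_length(h, w, features, stride):
    """Length of the flattened SAME-padded conv output."""
    features, stride = int(features), int(stride)  # inputs arrive as floats
    h_out = math.ceil(h / stride)
    w_out = math.ceil(w / stride)
    return h_out * w_out * features


# 3x3 input, 4 output features, stride 2: 2 * 2 * 4 values.
print(expected_flat_length(3, 3, features=4.0, stride=2.0))  # 16
```

The kernel size does not appear in the formula because SAME padding makes the output size depend only on the input size and the stride.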