NNX Module Basics
Why this matters
flax.nnx is Flax’s modern object-oriented API. It is the recommended starting
point for new Flax projects in 2025+. Unlike Flax Linen — where modules are pure
blueprints and parameters live OUTSIDE the module in a params dict — nnx
modules own their parameters as attributes. The model IS the forward function;
construction is initialization, calling is the forward pass.
If you already know Linen, mentally collapse the model.init(key, x) /
model.apply(params, x) round-trip into one step. If you are coming from
PyTorch, the mental model is closer to home: a module is a Python object whose
attributes hold its weights, and you call it like a function. The new wrinkle
over PyTorch is that the weights are JAX arrays wrapped in nnx.Param — small
metadata wrappers that let nnx and JAX cooperate on tracing, vmapping, and
sharding.
This first problem builds the simplest possible nnx module: one parameter, one line of forward code. Internalize the construction shape; everything else in the track builds on it.
API: building a minimal nnx.Module
Three ingredients:
- Subclass nnx.Module. Don’t add a metaclass, don’t decorate the class. Just inherit.
- Allocate parameters in __init__. Each parameter is a regular Python attribute holding an nnx.Param(...) value. nnx.Param is a thin wrapper around a JAX array; it tells nnx “this is a trainable weight.”
- Define __call__. Inside it, read parameters via attribute access. The wrapper unwraps automatically when you use it in math, but you can also write self.kernel.value to be explicit.
import jax
import jax.numpy as jnp
from flax import nnx

class MyModule(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        # Draw one fresh key for the kernel init.
        key = rngs.params()
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))
        )

    def __call__(self, x):
        return x @ self.kernel
rngs is an nnx.Rngs instance: a state container that hands out PRNG keys
on demand. Each call to rngs.params() returns a fresh key (the stream advances
an internal counter and derives a new key from it), which is why nnx code never
calls jax.random.split manually for parameter init.
Worked example
rngs = nnx.Rngs(0) # seeded once
model = MyModule(in_features=3, out_features=4, rngs=rngs)
print(model.kernel.value.shape) # (3, 4)
x = jnp.array([1.0, 2.0, 3.0])
y = model(x) # shape (4,)
Compare with the equivalent Linen flow:
# Linen, for contrast only
import flax.linen as nn

class MyDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=self.features, use_bias=False)(x)

model = MyDense(features=4)
params = model.init(jax.random.PRNGKey(0), x)  # extra step
y = model.apply(params, x)                     # extra step
Two API hops collapse into one.
Common pitfalls
- Forgetting rngs. Module __init__ MUST take an nnx.Rngs argument and use it to draw keys; the seed is what makes init reproducible. (You can omit it for parameter-less modules, but that’s rare.)
- rngs.params() not rngs. The method call returns the key; the bare object does not.
- Assigning a bare jnp.zeros(...) array is fine for constants or non-trainable buffers, but trainable parameters (biases included) must be wrapped in nnx.Param(...) so optimizers and nnx.state(model, nnx.Param) see them.
- Don’t call init or apply. They don’t exist on nnx modules. Build and call.
- int(features): the harness passes numeric inputs as floats, so cast shape arguments to int before using them.
Problem
Write module_basics_forward(seed, x, features):
- Define an nnx.Module subclass with one trainable parameter kernel of shape (in_features, out_features). Initialize it from a single rngs.params() key as jax.random.normal * (1 / sqrt(in_features)).
- The forward pass returns x @ self.kernel. No bias.
- Build nnx.Rngs(int(seed)), instantiate the module with in_features=x.shape[-1] and out_features=int(features), and return the output of calling the module on x.
Inputs:
- seed: int (passed as float; cast to int).
- x: 1-D JAX array.
- features: int (passed as float; cast to int).
Output: 1-D array of length features.