
NNX Module Basics

Why this matters

flax.nnx is Flax’s modern object-oriented API and, as of 2025, the recommended starting point for new Flax projects. Unlike Flax Linen, where modules are pure blueprints and parameters live OUTSIDE the module in a params dict, nnx modules own their parameters as attributes. The module object bundles the weights with the forward function: construction is initialization, and calling the module is the forward pass.

If you already know Linen, mentally collapse the model.init(key, x) / model.apply(params, x) round-trip into one step. If you are coming from PyTorch, the mental model is closer to home: a module is a Python object whose attributes hold its weights, and you call it like a function. The new wrinkle over PyTorch is that the weights are JAX arrays wrapped in nnx.Param — small metadata wrappers that let nnx and JAX cooperate on tracing, vmapping, and sharding.

This first problem builds the simplest possible nnx module: one parameter, one line of forward code. Internalize this construct-then-call pattern; everything else in the track builds on it.

API: building a minimal nnx.Module

Three ingredients:

  1. Subclass nnx.Module. Don’t add a metaclass, don’t decorate the class. Just inherit.
  2. Allocate parameters in __init__. Each parameter is a regular Python attribute holding an nnx.Param(...) value. nnx.Param is a thin wrapper around a JAX array; it tells nnx “this is a trainable weight.”
  3. Define __call__. Inside it, read parameters via attribute access. The wrapper unwraps automatically when you use it in math, but you can also write self.kernel.value to be explicit.
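
import jax
import jax.numpy as jnp
from flax import nnx

class MyModule(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        key = rngs.params()                      # draw one fresh key for this parameter
        # normal(0, 1) scaled by 1/sqrt(fan_in)
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))
        )

    def __call__(self, x):
        return x @ self.kernel                   # nnx.Param unwraps inside the matmul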

rngs is an nnx.Rngs instance, a stateful container that hands out PRNG keys on demand. Each call to rngs.params() returns a fresh key (the stream advances an internal counter and derives a new key from its seed), which is why nnx code rarely needs to call jax.random.split manually for parameter initialization.

Worked example

rngs = nnx.Rngs(0)                        # seeded once
model = MyModule(in_features=3, out_features=4, rngs=rngs)
print(model.kernel.value.shape)           # (3, 4)
x = jnp.array([1.0, 2.0, 3.0])
y = model(x)                              # shape (4,)

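Compare with the equivalent Linen flow (same structure; Linen’s nn.Dense uses its own default kernel initializer, so the numbers will differ):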

# Linen, for contrast only
import jax
from flax import linen as nn

class MyDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=self.features, use_bias=False)(x)

model = MyDense(features=4)
params = model.init(jax.random.PRNGKey(0), x)   # extra step: returns {'params': {...}}
y = model.apply(params, x)                       # extra step

Linen’s two extra steps (init, then apply) collapse into nnx’s construct-then-call.

Common pitfalls

  • Forgetting rngs. Module __init__ MUST take an nnx.Rngs argument and use it to draw keys; the seed is what makes init reproducible. (You can omit it for parameter-less modules, but that’s rare.)
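  • rngs.params(), not rngs. jax.random functions need a key; the method call returns one, but the bare Rngs object is not a key.
  • Wrap every trainable weight in nnx.Param(...), biases included. A bare array such as jnp.zeros(...) stored as an attribute is not an nnx.Param, so optimizers and nnx.state(model, nnx.Param) will not treat it as trainable. Non-trainable buffers belong in another nnx.Variable subclass (for example nnx.BatchStat).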
  • Don’t call init or apply. They don’t exist on nnx modules. Build and call.
  • int(features): the harness passes numeric inputs as floats — cast shape arguments to int before using them.

Problem

Write module_basics_forward(seed, x, features):

  1. Define an nnx.Module subclass with one trainable parameter kernel of shape (in_features, out_features). Initialize it from a single rngs.params() key as jax.random.normal(key, (in_features, out_features)) * (1 / sqrt(in_features)).
  2. The forward pass returns x @ self.kernel. No bias.
  3. Build nnx.Rngs(int(seed)), instantiate the module with in_features=x.shape[-1] and out_features=int(features), and return the output of calling the module on x.

Inputs:

  • seed: int (passed as float — cast to int).
  • x: 1-D JAX array.
  • features: int (passed as float — cast to int).

Output: 1-D array of length features.
