
NNX Module With Submodule

Why this matters

Real models are made of layers, and layers are made of layers. The composition pattern in nnx is submodules as attributes — exactly like PyTorch:

self.linear1 = nnx.Linear(in_features, hidden, rngs=rngs)
self.linear2 = nnx.Linear(hidden, out, rngs=rngs)

Once a submodule is assigned as an attribute, its parameters automatically become part of the outer module’s state pytree. nnx.split(outer) returns a (graphdef, state) pair, and the state is nested to mirror the attributes: {linear1: {kernel, bias}, linear2: {kernel, bias}}. No extra registration code is needed.

Compare with Linen’s @nn.compact style, where submodules are declared INSIDE __call__:

# Linen
@nn.compact
def __call__(self, x):
    x = nn.Dense(features=self.hidden)(x)   # auto-named Dense_0
    x = nn.Dense(features=self.out)(x)      # auto-named Dense_1
    return x

nnx’s attribute-based composition is more explicit about names and lifetimes: the layer is constructed once in __init__, exists for the life of the outer module, and can be printed for debugging or edited in place (outer.linear1.kernel.value = ...).

API: nnx.Linear (built-in)

Flax ships an nnx.Linear that’s ready to use. Signature:

nnx.Linear(
    in_features,    # int
    out_features,   # int
    use_bias=True,  # default True
    rngs=...,       # required: nnx.Rngs to seed init
    # ... plus dtype/initializer kwargs we'll ignore here
)

By default it uses LeCun-normal kernel initialization and a zero bias. Calling it on a 1-D (in_features,) or 2-D (N, in_features) input returns an array with out_features along the last axis.

Pass the SAME rngs to both submodules. Each submodule draws fresh keys via rngs.params() internally, and the draws don’t collide because every call to the stream returns a new key. Be sure to pass the Rngs container (not a raw key), since submodules expect the container.

Worked example

import jax
from flax import nnx

class MLP(nnx.Module):
    def __init__(self, in_features, hidden, out, rngs):
        self.linear1 = nnx.Linear(in_features, hidden, rngs=rngs)
        self.linear2 = nnx.Linear(hidden, out, rngs=rngs)

    def __call__(self, x):
        x = self.linear1(x)
        x = jax.nn.relu(x)
        return self.linear2(x)

model = MLP(in_features=4, hidden=8, out=2, rngs=nnx.Rngs(0))
print(model.linear1.kernel.value.shape)   # (4, 8)
print(model.linear2.kernel.value.shape)   # (8, 2)

Common pitfalls

  • Reordering submodule construction. Keys from rngs.params() are consumed in order, so if you reorder the construction of submodules, you’ll get different parameter values for the same seed. Construct them in the order they’ll be called, for clarity.
  • Sharing one submodule across two attributes. If you do self.a = self.b = nnx.Linear(...), both attributes refer to the SAME Python object (parameter sharing — useful sometimes, but not what we want here).
  • Forgetting the rngs argument. nnx.Linear(...) without rngs=... raises an error because it has no key source to initialize the kernel.
  • Wrong activation placement. Standard MLP is linear → activation → linear. The activation is between the two linears, not after the second one (we want a linear last layer).

Problem

Write submodule_compose(seed, x, hidden, out):

  1. Define MLP(nnx.Module) with two nnx.Linear submodules: self.linear1 = nnx.Linear(in_features, hidden, rngs=rngs) and self.linear2 = nnx.Linear(hidden, out, rngs=rngs).
  2. Forward: x = linear1(x); x = relu(x); return linear2(x).
  3. Build with nnx.Rngs(int(seed)), instantiate the MLP (in_features=x.shape[-1], hidden=int(hidden), out=int(out)), return its output on x.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • hidden: int (passed as float).
  • out: int (passed as float).

Output: 1-D array of length out.

Hints

flax nnx submodule mlp
