NNX Module With Submodule
Why this matters
Real models are made of layers, and layers are made of layers. The composition pattern in nnx is submodules as attributes — exactly like PyTorch:
self.linear1 = nnx.Linear(in_features, hidden, rngs=rngs)
self.linear2 = nnx.Linear(hidden, out, rngs=rngs)
Once a submodule is assigned as an attribute, its parameters become part of
the outer module’s state pytree automatically. nnx.split(outer) returns a
(graphdef, state) pair whose state is nested as
{linear1: {kernel, bias}, linear2: {kernel, bias}}, with no extra
registration code needed.
Compare with Linen’s @nn.compact style, where submodules are declared
INSIDE __call__:
# Linen
@nn.compact
def __call__(self, x):
    x = nn.Dense(features=self.hidden)(x)  # auto-named Dense_0
    x = nn.Dense(features=self.out)(x)     # auto-named Dense_1
    return x
nnx’s attribute-based composition is more explicit about names and lifetimes;
the layer is constructed once in __init__, exists for the life of the
outer module, and you can debug-print or surgery-edit it directly
(outer.linear1.kernel.value = ...).
API: nnx.Linear (built-in)
Flax ships an nnx.Linear that’s ready to use. Signature:
nnx.Linear(
    in_features,    # int
    out_features,   # int
    use_bias=True,  # default True
    rngs=...,       # required: nnx.Rngs to seed init
    # ... plus dtype/initializer kwargs we'll ignore here
)
By default it uses LeCun-normal kernel initialization and a zero bias.
Calling it on a 1-D (in_features,) or 2-D (N, in_features) input returns an
array with out_features along the last axis.
Pass the SAME rngs to both submodules. Each submodule pulls fresh keys
via rngs.params() internally; successive draws don’t collide because
Rngs advances an internal counter and derives a new key on every draw.
Be sure to pass rngs (not a raw key): submodules expect the container.
Worked example
import jax
from flax import nnx

class MLP(nnx.Module):
    def __init__(self, in_features, hidden, out, rngs):
        self.linear1 = nnx.Linear(in_features, hidden, rngs=rngs)
        self.linear2 = nnx.Linear(hidden, out, rngs=rngs)

    def __call__(self, x):
        x = self.linear1(x)
        x = jax.nn.relu(x)  # activation sits between the two linears
        return self.linear2(x)

model = MLP(in_features=4, hidden=8, out=2, rngs=nnx.Rngs(0))
print(model.linear1.kernel.value.shape)  # (4, 8)
print(model.linear2.kernel.value.shape)  # (8, 2)
Common pitfalls
- Constructing the same submodule twice in a non-deterministic way. rngs.params() is consumed in order: if you reorder the construction of submodules, you’ll get different parameter values for the same seed. Construct them in the order they’ll be called for clarity.
- Sharing one submodule across two attributes. If you do self.a = self.b = nnx.Linear(...), both attributes refer to the SAME Python object (parameter sharing, which is useful sometimes but not what we want here).
- Forgetting the rngs arg. nnx.Linear(...) without rngs=... errors because it can’t initialize the kernel.
- Wrong activation placement. Standard MLP is linear → activation → linear. The activation is between the two linears, not after the second one (we want a linear last layer).
Problem
Write submodule_compose(seed, x, hidden, out):
- Define MLP(nnx.Module) with two nnx.Linear submodules: self.linear1 = nnx.Linear(in_features, hidden, rngs=rngs) and self.linear2 = nnx.Linear(hidden, out, rngs=rngs).
- Forward: x = linear1(x); x = relu(x); return linear2(x).
- Build with nnx.Rngs(int(seed)), instantiate the MLP (in_features=x.shape[-1], hidden=int(hidden), out=int(out)), return its output on x.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- hidden: int (passed as float).
- out: int (passed as float).
Output: 1-D array of length out.