NNX Deep Nesting
Why this matters
Real models nest several layers deep. A Transformer block contains an attention sub-block which contains four linear projections; an entire Transformer contains N blocks. nnx supports arbitrary nesting via the same submodule-as-attribute pattern from problem 6 — recursively.
What's nice: when you call nnx.split(model), the resulting state pytree mirrors the module structure. So if the model has the attribute path model.middle.inner.kernel, the state has the nested keys state['middle']['inner']['kernel']. Because state keys map 1:1 to attribute paths, surgery, debug-printing, and weight loading become straightforward.
This problem builds three nesting levels: OuterModule -> MiddleModule -> InnerLinear. You'll see the structure mirrored in the state pytree in problem 12.
API: nesting nnx Modules
import jax
import jax.numpy as jnp
from flax import nnx

class InnerLinear(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        key = rngs.params()  # fresh key drawn from the shared Rngs
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))  # scale by 1/sqrt(fan-in)
        )
        self.bias = nnx.Param(jnp.zeros((out_features,)))

    def __call__(self, x):
        return x @ self.kernel + self.bias

class MiddleModule(nnx.Module):
    def __init__(self, in_features, hidden, rngs):
        # The submodule is just an attribute; pass the same rngs down.
        self.inner = InnerLinear(in_features, hidden, rngs=rngs)

    def __call__(self, x):
        return jax.nn.relu(self.inner(x))

class OuterModule(nnx.Module):
    def __init__(self, in_features, hidden, rngs):
        self.middle = MiddleModule(in_features, hidden, rngs=rngs)

    def __call__(self, x):
        return self.middle(x)  # no extra activation at this level

model = OuterModule(in_features=4, hidden=8, rngs=nnx.Rngs(0))
print(model.middle.inner.kernel.value.shape)  # (4, 8)
Pass the SAME rngs instance through every level. Each draw such as rngs.params() advances the stream's internal counter and returns a fresh key, so descendants get independent keys without ever sharing one.
How nesting interacts with state
nnx.split(model) returns (graphdef, state), where state is a nested pytree:
graphdef, state = nnx.split(model)
# state ~= {
# 'middle': {
# 'inner': {
# 'kernel': nnx.Param(...),
# 'bias': nnx.Param(...),
# }
# }
# }
state['middle']['inner']['kernel'].value.shape # (4, 8)
The graphdef captures the static shape of the tree (which submodules exist where); state holds the dynamic values. We’ll work with this in problems 12 and 14.
Common pitfalls
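- Re-using a key across levels. If you draw key = rngs.params() once at the outer level and hand that same raw key to several children, every child that uses it directly produces the same random values (for example, identical kernels). Children that call rngs.params() themselves are safe, because each call returns a fresh key. Just pass rngs, not a key.
- Inconsistent activation placement. MiddleModule wraps its inner layer with relu; OuterModule just calls middle. Don't add a second relu in OuterModule. Because relu is idempotent it would not change the output here, but it hides where the nonlinearity actually lives.
- Forgetting bias in InnerLinear. The reference initializes bias to zeros, so dropping it would not change the initial output, but the parameter tree would no longer match the reference (there would be no 'bias' leaf) and the layer could not learn an offset.
- Naming attributes inconsistently. Use self.middle (not self.mid, not self._middle). The test isn't sensitive to attribute names, but attribute names become the keys in the state pytree, so descriptive, consistent names keep state paths readable and weight loading compatible.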
Problem
Write deep_nested_forward(seed, x, hidden):
- Define InnerLinear(nnx.Module): kernel and bias, both nnx.Param; forward is x @ kernel + bias. Initialize kernel with jax.random.normal(key, (in_features, out_features)) * (1 / sqrt(in_features)).
- Define MiddleModule(nnx.Module): contains one self.inner = InnerLinear(...). Forward returns relu(self.inner(x)).
- Define OuterModule(nnx.Module): contains one self.middle = MiddleModule(...). Forward returns self.middle(x).
- Build nnx.Rngs(int(seed)), instantiate OuterModule(in_features=x.shape[-1], hidden=int(hidden), rngs=rngs), and return the forward output.

The result is relu(x @ inner_kernel + inner_bias), a vector of length hidden.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- hidden: int (passed as float).

Output: 1-D array of length hidden.