medium primitives

NNX Deep Nesting

Why this matters

Real models nest several layers deep. A Transformer block contains an attention sub-block which contains four linear projections; an entire Transformer contains N blocks. nnx supports arbitrary nesting via the same submodule-as-attribute pattern from problem 6 — recursively.

What’s nice: when you call nnx.split(model), the resulting state pytree mirrors the module structure. If a parameter lives at model.middle.inner.kernel, the state holds it at state['middle']['inner']['kernel']. Because the keys map 1:1 to attribute paths, model surgery, debug-printing, and weight loading reduce to plain nested lookups.

This problem builds three nesting levels: OuterModule -> MiddleModule -> InnerLinear. You’ll see the structure mirrored in the state pytree in problem 12.

API: nesting nnx Modules

import jax
import jax.numpy as jnp
from flax import nnx

class InnerLinear(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        key = rngs.params()  # draw one fresh key from the shared Rngs
        # Normal init scaled by 1 / sqrt(in_features)
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))
        )
        self.bias = nnx.Param(jnp.zeros((out_features,)))

    def __call__(self, x):
        return x @ self.kernel + self.bias

class MiddleModule(nnx.Module):
    def __init__(self, in_features, hidden, rngs):
        # Submodule stored as an attribute; rngs is passed through unchanged
        self.inner = InnerLinear(in_features, hidden, rngs=rngs)

    def __call__(self, x):
        return jax.nn.relu(self.inner(x))  # the only activation in the stack

class OuterModule(nnx.Module):
    def __init__(self, in_features, hidden, rngs):
        self.middle = MiddleModule(in_features, hidden, rngs=rngs)

    def __call__(self, x):
        return self.middle(x)  # no extra activation at this level

model = OuterModule(in_features=4, hidden=8, rngs=nnx.Rngs(0))
print(model.middle.inner.kernel.value.shape)   # (4, 8)

Pass the SAME rngs instance through every level. Each rngs.params() call returns a fresh key and advances the stream, so every descendant that draws gets its own key and no two draws share one.

How nesting interacts with state

nnx.split(model) returns (graphdef, state) where state is a nested pytree:

graphdef, state = nnx.split(model)
# state ~= {
#     'middle': {
#         'inner': {
#             'kernel': nnx.Param(...),
#             'bias': nnx.Param(...),
#         }
#     }
# }
state['middle']['inner']['kernel'].value.shape    # (4, 8)

The graphdef captures the static structure of the tree (which submodules exist where); state holds the dynamic values. We’ll work with this in problems 12 and 14.

Common pitfalls

  • Re-using a key across levels. If you draw key = rngs.params() once at the outer level and hand that key to the children, every child that samples from it directly gets identical random values, while any child that calls rngs.params() itself gets a different key from the one you passed. Avoid the mix-up: pass rngs, not a key.
  • Inconsistent activation placement. The middle module wraps its inner with relu. The outer just calls middle. Don’t double-relu.
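  • Forgetting bias in InnerLinear. The reference initializes bias to zeros, so omitting it leaves the initial forward output unchanged, but the state pytree then lacks the 'bias' entry that the reference structure contains. Keep it.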
  • Naming attributes inconsistently. self.middle (not self.mid, not self._middle) — the test isn’t sensitive to attribute names, but for readability and downstream compatibility, prefer descriptive names.

Problem

Write deep_nested_forward(seed, x, hidden):

  1. Define InnerLinear(nnx.Module): kernel + bias nnx.Param, forward x @ kernel + bias. Init kernel with jax.random.normal(key, (in_features, out_features)) * (1 / sqrt(in_features)).
  2. Define MiddleModule(nnx.Module): contains one self.inner = InnerLinear(...). Forward returns relu(self.inner(x)).
  3. Define OuterModule(nnx.Module): contains one self.middle = MiddleModule(...). Forward returns self.middle(x).
  4. Build nnx.Rngs(int(seed)), instantiate OuterModule(in_features=x.shape[-1], hidden=int(hidden), rngs=rngs), return the forward output.

The result is relu(x @ inner_kernel + inner_bias), length hidden.

Inputs:

  • seed: integer value, passed as a float; convert with int(seed).
  • x: 1-D JAX array.
  • hidden: integer value, passed as a float; convert with int(hidden).

Output: 1-D array of length hidden.

Hints

flax nnx nesting composition
