Three Levels of Module Nesting

Why this matters

Real architectures are nested. A Transformer is a stack of TransformerBlocks. Each block contains an attention layer, a feed-forward network (FFN) and LayerNorms, and the attention and FFN layers are themselves built from several Dense layers. Three or four levels of nesting are the norm, and ten is not unusual in production codebases.

Flax handles arbitrary nesting cleanly: a Module can hold other Modules as sub-modules, those can hold sub-sub-modules, and so on. The params dict mirrors this structure exactly, with one level of dict nesting per level of Module nesting.

How params reflect nesting

For three levels of @nn.compact Modules:

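import jax
import flax.linen as nn

# Level 3 (innermost): a thin wrapper around a single Dense layer.
class InnerDense(nn.Module):
    features: int
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=self.features)(x)

# Level 2: InnerDense followed by a ReLU.
class MiddleBlock(nn.Module):
    features: int
    @nn.compact
    def __call__(self, x):
        x = InnerDense(features=self.features)(x)
        return jax.nn.relu(x)

# Level 1 (outermost): MiddleBlock, then a Dense whose width equals
# MiddleBlock's output width (i.e. hidden).
class OuterModel(nn.Module):
    hidden: int
    @nn.compact
    def __call__(self, x):
        x = MiddleBlock(features=self.hidden)(x)
        return nn.Dense(features=x.shape[-1])(x)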

Initializing with an input of shape (3,) and hidden=4 gives this tree (each leaf is shown as its array shape):

params["params"] = {
    "MiddleBlock_0": {
        "InnerDense_0": {
            "Dense_0": {"kernel": (3, 4), "bias": (4,)}
        }
    },
    "Dense_0": {"kernel": (4, 4), "bias": (4,)}
}

Read the structure: OuterModel contains MiddleBlock_0 and Dense_0 (top level). MiddleBlock_0 contains InnerDense_0. InnerDense_0 contains its own Dense_0. The dot-separated path is MiddleBlock_0.InnerDense_0.Dense_0.

Auto-naming inside compact

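Inside @nn.compact, each sub-module is auto-named when it is created, using the pattern ClassName_index: Dense_0, Dense_1, MiddleBlock_0, LayerNorm_0, and so on. The index counter is kept separately for each class within a given parent Module, so the first Dense and the first LayerNorm are both numbered 0. Because sub-modules are created as __call__ runs, the names in the params dict follow the order in which they appear in __call__.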

Inside setup(), the assigned attribute name is used: self.foo = nn.Dense(...) appears as params["params"]["foo"].

Inspecting nested params

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

# params is the full variables dict returned by OuterModel(hidden=4).init(...)

# Visualize all shapes:
print(jtu.tree_map(jnp.shape, params))

# Reach into a specific layer:
inner_kernel = params["params"]["MiddleBlock_0"]["InnerDense_0"]["Dense_0"]["kernel"]

# Print only the tree structure (no array values):
leaves, treedef = jtu.tree_flatten(params)
print(treedef)

Why this is more than a toy

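  • When loading pretrained weights, you walk the source dict and insert each array into this nested structure. Mismatched nesting is a common cause of “wrong shape” or missing-key errors during weight loading.
  • When freezing only some layers, you build a mask of the same nested structure (True = trainable, False = frozen) and hand it to optax. Note that optax.masked applies its inner transformation only where the mask is True and passes the other updates through unchanged, so freezing typically pairs the mask with optax.set_to_zero(), for example via optax.multi_transform.
  • When sharding, you pair this nested params tree with a parallel “sharding tree” of the same structure that tells JAX where each parameter lives.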

Common pitfalls

  • Sub-module created but not called inside @nn.compact: instantiating MiddleBlock(...) is not enough; you must call it, as in MiddleBlock(...)(x). Until it is called, none of its params are added to the tree.
  • Re-using the same instance: with block = MiddleBlock(...) followed by two calls to block(...), Flax registers the Module once and both calls use the same params (parameter sharing). This is sometimes exactly what you want (e.g. tied weights) and sometimes a bug.
  • Mixing setup and compact: different Modules in one hierarchy can use different styles without problems, but keep each individual Module to one style for readability.

Problem

Build the three-level nesting exactly as shown above:

  • InnerDense(features): just a Dense(features).
  • MiddleBlock(features): InnerDense(features) then ReLU.
  • OuterModel(hidden): MiddleBlock(hidden) then Dense(features=x.shape[-1]), where x is MiddleBlock's output, so the final Dense also has hidden units.

All Modules use @nn.compact. The function initializes OuterModel with PRNGKey(seed), applies it to x, and returns the result.

Inputs:

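  • seed: int (passed as a float; convert with int() before building the PRNGKey).
  • x: 1-D JAX array.
  • hidden: int (passed as a float; convert with int() before using it as a Module attribute).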

Output: 1-D array, length hidden.

Hints

flax module nesting
