Three Levels of Module Nesting
Why this matters
Real architectures are nested. A Transformer is a stack of TransformerBlocks,
each containing an Attention and an FFN, each of those containing several
Denses and LayerNorms. Three or four levels of nesting is the norm; ten is
not unusual in production codebases.
Flax handles arbitrary nesting cleanly: a Module instance can hold other Modules as sub-modules, those can hold sub-sub-modules, and so on. The params dict mirrors the structure exactly — one nested dict level per Module level.
How params reflect nesting
For three levels of @nn.compact Modules:
import jax
import flax.linen as nn

class InnerDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=self.features)(x)

class MiddleBlock(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = InnerDense(features=self.features)(x)
        return jax.nn.relu(x)

class OuterModel(nn.Module):
    hidden: int

    @nn.compact
    def __call__(self, x):
        x = MiddleBlock(features=self.hidden)(x)
        # x now has shape (hidden,), so this Dense outputs `hidden` features.
        return nn.Dense(features=x.shape[-1])(x)
Init with x.shape == (3,), hidden=4:
params["params"] = {
    "MiddleBlock_0": {
        "InnerDense_0": {
            "Dense_0": {"kernel": (3, 4), "bias": (4,)}
        }
    },
    "Dense_0": {"kernel": (4, 4), "bias": (4,)}
}
Read the structure: OuterModel contains MiddleBlock_0 and Dense_0 (top
level). MiddleBlock_0 contains InnerDense_0. InnerDense_0 contains its
own Dense_0. The dot-separated path is MiddleBlock_0.InnerDense_0.Dense_0.
Auto-naming inside compact
Inside @nn.compact, every sub-module is auto-named in the order it is
created, which normally matches call order: Dense_0, Dense_1,
MiddleBlock_0, LayerNorm_0, etc. Within a given Module the counter is
per-class, so a LayerNorm placed between two Denses yields Dense_0,
LayerNorm_0, Dense_1. These auto-names are the keys of the params dict.
Inside setup(), the assigned attribute name is used: self.foo = nn.Dense(...)
appears as params["params"]["foo"].
Inspecting nested params
import jax
import jax.numpy as jnp

# Visualize all shapes:
print(jax.tree_util.tree_map(jnp.shape, params))
# Reach into a specific layer:
inner_kernel = params["params"]["MiddleBlock_0"]["InnerDense_0"]["Dense_0"]["kernel"]
# Print the tree structure (the treedef) without the leaf values:
leaves, treedef = jax.tree_util.tree_flatten(params)
print(treedef)
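For deeper trees it is often easier to list every leaf by its dot-separated path. The helper below is a hypothetical `iter_paths` written in plain Python, not a Flax function, and the shape tree is copied from the example above instead of coming from a real `init` call. Flax also provides `flax.traverse_util.flatten_dict`, which accepts a `sep` argument and does a similar job.

```python
from collections.abc import Mapping


def iter_paths(tree, prefix=()):
    """Yield (dot_path, leaf) for every leaf of a nested mapping."""
    for name, child in tree.items():
        path = prefix + (name,)
        if isinstance(child, Mapping):
            # Descend one Module level.
            yield from iter_paths(child, path)
        else:
            yield ".".join(path), child


# Shape tree from the three-level example (stand-in for params["params"]).
shapes = {
    "MiddleBlock_0": {
        "InnerDense_0": {
            "Dense_0": {"kernel": (3, 4), "bias": (4,)}
        }
    },
    "Dense_0": {"kernel": (4, 4), "bias": (4,)},
}

paths = dict(iter_paths(shapes))
for path, shape in paths.items():
    print(path, shape)
```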
Why this is more than a toy
- When loading pretrained weights, you walk the source dict and emplace its values into this nested structure. Mismatched nesting is a common cause of “wrong shape” errors during weight loading.
- When freezing only some layers, you build a “freeze mask” with the same nested structure (True = trainable, False = frozen) and pass it to optax.masked. Note that optax.masked passes the updates of False leaves through unchanged, so truly freezing them also requires zeroing those updates, for example with optax.set_to_zero.
- When sharding, you pair this nested params tree with a parallel “sharding tree” telling JAX where each parameter lives.
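A freeze mask for the three-level example might be built as follows. This is a sketch under two assumptions: the mask uses boolean leaves (the form documented for `optax.masked`), and the choice to freeze everything under `MiddleBlock_0` is made up for illustration. The zero-filled `params` tree stands in for the result of a real `init` call.

```python
import jax
import jax.numpy as jnp

# Stand-in for the params tree of OuterModel(hidden=4) with x.shape == (3,).
params = {
    "params": {
        "MiddleBlock_0": {
            "InnerDense_0": {
                "Dense_0": {
                    "kernel": jnp.zeros((3, 4)),
                    "bias": jnp.zeros((4,)),
                }
            }
        },
        "Dense_0": {"kernel": jnp.zeros((4, 4)), "bias": jnp.zeros((4,))},
    }
}


def is_trainable(path, _leaf):
    """True for leaves that should be trained, False for frozen ones."""
    # Each path entry for a dict level is a DictKey; .key is the dict key.
    names = [entry.key for entry in path]
    return "MiddleBlock_0" not in names   # hypothetical freezing rule


# Same nested structure as params, with a bool at every leaf.
trainable_mask = jax.tree_util.tree_map_with_path(is_trainable, params)

# The complementary mask marks the frozen leaves.
frozen_mask = jax.tree_util.tree_map(lambda t: not t, trainable_mask)

print(trainable_mask)
```

One possible way to use these masks is to wrap `optax.set_to_zero()` in `optax.masked` with `frozen_mask`, chained after the main optimizer, so that the frozen leaves receive zero updates.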
Common pitfalls
- Sub-module not called inside @nn.compact: just instantiating MiddleBlock(...) is not enough; you must call it: MiddleBlock(...)(x). Otherwise nothing is added to the params tree.
- Re-using the same instance: block = MiddleBlock(...) then block(x) twice. Flax registers the Module once, and each call re-uses the same params (parameter sharing). Often what you want; sometimes not.
- Nested setup vs compact mixed sloppily: works, but stick to one style per Module for readability.
Problem
Build the three-level nesting exactly as shown above:
- InnerDense(features): just a Dense(features).
- MiddleBlock(features): InnerDense(features) then ReLU.
- OuterModel(hidden): MiddleBlock(hidden) then Dense(features=x.shape[-1]).
All Modules use @nn.compact. The function inits with PRNGKey(seed) and
applies to x.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- hidden: int (passed as float).
Output: 1-D array, length hidden.