Module with Multiple Named Sub-Modules
Why this matters
Real architectures aren’t a single Dense — they’re encoders, decoders, attention
blocks, residual blocks, each owning their own sub-modules. A Transformer block
contains an attn, an mlp, and two LayerNorms. A ResNet block contains two
Conv2Ds and a BatchNorm.
The way you express this in Flax is to give each sub-module a name in
setup() and reference it by attribute (self.encoder(x)) inside __call__.
The names you choose appear directly in the params dict — they become the keys
you use later for surgery, freezing, and weight loading.
setup() vs @nn.compact for multi-submodule
Either works, but setup() is preferred when:
- You have multiple named sub-modules and want explicit control over names.
- A sub-module is referenced more than once (you want to reuse it).
- You want the sub-modules to feel like attributes (self.encoder.features).
@nn.compact is fine for one-off chains; setup() shines when structure
matters.
Worked example: encoder/decoder
import jax
import jax.numpy as jnp
import flax.linen as nn

class EncoderDecoder(nn.Module):
    hidden: int
    out: int

    def setup(self):
        self.encoder = nn.Dense(features=self.hidden, name="encoder")
        self.decoder = nn.Dense(features=self.out, name="decoder")

    def __call__(self, x):
        h = jax.nn.relu(self.encoder(x))
        return self.decoder(h)

model = EncoderDecoder(hidden=4, out=3)
key = jax.random.PRNGKey(0)
x = jnp.array([1.0, 2.0, 3.0])
params = model.init(key, x)

# The params dict mirrors the names you chose:
print(jax.tree_util.tree_map(jnp.shape, params))
# {'params': {
#     'encoder': {'kernel': (3, 4), 'bias': (4,)},
#     'decoder': {'kernel': (4, 3), 'bias': (3,)},
# }}
Notice: the keys in params["params"] are exactly "encoder" and "decoder", the names you gave the sub-modules. Had you left name= off, setup() would still use the attribute names as keys. It is inside @nn.compact that unnamed sub-modules get auto-generated names such as Dense_0 and Dense_1.
Why explicit names matter
Three concrete reasons:
- Param surgery: params["params"]["encoder"] is a stable handle. Auto-names like Dense_0 shift if you re-order sub-modules, which makes them fragile.
- Weight loading: when porting weights from a pretrained checkpoint, you map source names to your sub-module names. Stable explicit names make the mapping trivial.
- Freezing / partial training: optax.masked takes a boolean mask shaped like the params tree, and you usually build that mask from the key names, so you want those names to be meaningful (encoder, not Dense_3).
Common pitfalls
- Forgetting name=...: self.encoder = nn.Dense(features=self.hidden) still works (the attribute name “encoder” becomes the key by default), but passing name="encoder" is explicit and survives refactors.
- Defining sub-modules outside setup: don’t do self.encoder = nn.Dense(...) in __init__. Flax expects sub-module definitions inside setup() so the framework can walk the tree.
- Mutating in __call__: never assign self.foo = something inside __call__. Sub-modules go in setup(), full stop.
Problem
Build an EncoderDecoder Module with two Dense sub-modules:
- encoder: Dense(features=hidden)
- decoder: Dense(features=out_size) where out_size = x.shape[0] (i.e., output dim equals input dim, for encoder/decoder symmetry).
Forward pass: decoder(relu(encoder(x))). Use setup() to name the sub-modules.
Inputs:
- seed: int seed (passed as float; cast to int).
- x: 1-D JAX array.
- hidden: int (cast to int), the encoder output size.
Output: 1-D array, same length as x.