
Module with Multiple Named Sub-Modules

Why this matters

Real architectures aren’t a single Dense: they’re encoders, decoders, attention blocks, and residual blocks, each owning its own sub-modules. A Transformer block contains an attention sub-module (attn), a feed-forward MLP (mlp), and two LayerNorms. A ResNet basic block contains two convolutions, each followed by a BatchNorm.

The way you express this in Flax is to give each sub-module a name in setup() and reference it by attribute (self.encoder(x)) inside __call__. The names you choose appear directly in the params dict — they become the keys you use later for surgery, freezing, and weight loading.

setup() vs @nn.compact for multi-submodule

Either works, but setup() is preferred when:

  • You have multiple named sub-modules and want explicit control over names.
  • A sub-module is used from more than one method (for example separate encode and decode methods), so it has to live outside any single __call__.
  • You want the sub-modules to feel like attributes (self.encoder.features).

@nn.compact is fine for one-off chains; setup() shines when structure matters.

Worked example: encoder/decoder

import jax
import jax.numpy as jnp
import flax.linen as nn

class EncoderDecoder(nn.Module):
    hidden: int
    out: int

    def setup(self):
        self.encoder = nn.Dense(features=self.hidden, name="encoder")
        self.decoder = nn.Dense(features=self.out, name="decoder")

    def __call__(self, x):
        h = jax.nn.relu(self.encoder(x))
        return self.decoder(h)

model = EncoderDecoder(hidden=4, out=3)
key = jax.random.PRNGKey(0)
x = jnp.array([1.0, 2.0, 3.0])
params = model.init(key, x)

# The params dict mirrors the names you chose
# (tree_map rebuilds dicts with sorted keys, so decoder prints first):
print(jax.tree_util.tree_map(jnp.shape, params))
# {'params': {
#     'decoder': {'bias': (3,), 'kernel': (4, 3)},
#     'encoder': {'bias': (4,), 'kernel': (3, 4)},
# }}

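Notice: the keys in params["params"] are exactly "encoder" and "decoder", the names of the sub-modules. Inside setup() the attribute name already supplies this key, so dropping name= would give the same result here. Auto-generated names such as Dense_0 and Dense_1 appear only when you create sub-modules inline in an @nn.compact method without a name.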

Why explicit names matter

Three concrete reasons:

  1. Param surgery: params["params"]["encoder"] is a stable handle. Auto-generated names like Dense_0 depend on creation order, so inserting or re-ordering sub-modules silently shifts them.
  2. Weight loading: when porting weights from a pretrained checkpoint, you map source names to your sub-module names. Stable explicit names make the mapping trivial.
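  3. Freezing / partial training: optax.masked (and optax.multi_transform) take a mask or label pytree with the same structure as params. You typically build that pytree from the parameter names, so you want those names to be meaningful (encoder, not Dense_3).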

Common pitfalls

  • Forgetting name=...: self.encoder = nn.Dense(features=self.hidden) still works (the attribute name “encoder” becomes the key by default), but passing name="encoder" is explicit and survives refactors.
  • Defining sub-modules outside setup: Flax Modules are dataclasses, so don’t override __init__ to build self.encoder = nn.Dense(...). Define sub-modules in setup() (or inline in an @nn.compact method) so Flax can bind each one to its place in the variable tree.
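  • Assigning attributes in __call__: never write self.foo = something inside __call__. Outside setup(), Module attributes are frozen and Flax raises an error. In a setup()-style Module, sub-modules go in setup(), full stop.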

Problem

Build an EncoderDecoder Module with two Dense sub-modules:

  • encoder: Dense(features=hidden)
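  • decoder: Dense(features=out_size) where out_size = x.shape[0], so the output dimension equals the input dimension (encoder/decoder symmetry). Because setup() never sees x, out_size has to be passed in as a Module attribute when you construct the model.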

Forward pass: decoder(relu(encoder(x))). Use setup() to name the sub-modules.

Inputs:

  • seed: int seed (passed as float — cast to int).
  • x: 1-D JAX array.
  • hidden: int (cast to int) — encoder output size.

Output: 1-D array, same length as x.

Hints

flax submodule setup
