Module with setup() (alternative to compact)
Why this matters
Flax provides two patterns for defining a Module’s submodules and forward logic:
@nn.compact and setup(). While @nn.compact is the recommended default for
simple cases, setup() is the right choice in several real-world scenarios — and
understanding both is essential for reading and writing production Flax code.
The two patterns side by side
@nn.compact — declare and call submodules inline in __call__:
class CompactModel(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.features)(x)  # declare Dense here
        return jax.nn.relu(x)
setup() — declare submodules in setup, call them in __call__:
class SetupModel(nn.Module):
    features: int

    def setup(self):
        self.dense = nn.Dense(self.features)  # declare here

    def __call__(self, x):
        x = self.dense(x)  # call here
        return jax.nn.relu(x)
Both define the same computation, but the parameter dict keys differ. In
compact, the Dense is auto-named Dense_0. In setup, the Dense is named dense
(the attribute you assigned in setup). This makes the two parameter dicts
structurally different:
# @nn.compact version
params["params"]["Dense_0"]["kernel"]
# setup() version
params["params"]["dense"]["kernel"]
The initial values also differ for the same seed: Flax derives each parameter's RNG from the submodule name, so the kernel stored under Dense_0 and the kernel stored under dense come from different random draws, and the two freshly initialized models produce different outputs.
When to use setup()
Use setup() instead of @nn.compact when:
- Multiple methods share the same submodule: e.g. an encoder used in both encode() and __call__. With compact, you can't share: each call to nn.Dense(...) inside compact creates a new layer. With setup, self.enc is shared across all methods.
- Custom submodule names: setup names each layer after the attribute you assign it to. This is more readable and stable than auto-numbered names such as Dense_0 (in compact you would have to pass name=... explicitly).
- Conditional or complex init logic: setup runs once for each init or apply call, and module fields such as features are plain Python values rather than traced arrays, so you can branch on hyperparameters with ordinary if statements.
- Readability at scale: for modules with 5+ submodules, setup + __call__ is cleaner than a long compact __call__.
Important rule: don’t mix them
Avoid combining setup() and an @nn.compact method in the same Module: it
splits submodule declarations across two places and two naming schemes, and
Flax may raise an error. Choose one style per Module.
Worked example comparing parameter keys
import jax
import jax.numpy as jnp  # needed for jnp.ones below
import flax.linen as nn


class CompactM(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(x)


class SetupM(nn.Module):
    features: int

    def setup(self):
        self.dense = nn.Dense(self.features)

    def __call__(self, x):
        return self.dense(x)


key = jax.random.PRNGKey(0)
x = jnp.ones((4,))
c_params = CompactM(features=2).init(key, x)
s_params = SetupM(features=2).init(key, x)
print(c_params)  # {'params': {'Dense_0': {'bias': ..., 'kernel': ...}}}
print(s_params)  # {'params': {'dense': {'bias': ..., 'kernel': ...}}}
Note: even with the same seed and inputs, the kernel values generally differ
because Flax folds the submodule name into the RNG when initializing, so
'Dense_0' and 'dense' produce different random draws. The bias is
zero-initialized by default in both cases.
Common pitfalls
- Forgetting to assign in setup: nn.Dense(self.features) in setup without self.X = ... does nothing: the layer is created but immediately discarded.
- Using compact and setup together: Flax may raise a FlaxError. Pick one.
- Expecting the same output as compact: the param names differ, so the RNG draws differ too. Don't assume numerical equivalence.
- Calling super().__init__: Flax Modules use a custom __init_subclass__ and Python dataclass machinery. Never call super().__init__() in setup.
Problem
Build a single-layer Flax Module using setup() (not @nn.compact) that wraps
one nn.Dense layer. Assign it as self.dense in setup(). Init and apply.
Inputs:
- seed: int (passed as float; cast to int).
- x: 1-D JAX array.
- features: int (cast to int).
Output: 1-D array of length features — the Dense layer output.