
Module with setup() (alternative to compact)

Why this matters

Flax provides two patterns for defining a Module’s submodules and forward logic: @nn.compact and setup(). While @nn.compact is the recommended default for simple cases, setup() is the right choice in several real-world scenarios — and understanding both is essential for reading and writing production Flax code.

The two patterns side by side

@nn.compact — declare and call submodules inline in __call__:

import jax
import flax.linen as nn

class CompactModel(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.features)(x)    # declare Dense here
        return jax.nn.relu(x)

setup() — declare submodules in setup, call them in __call__:

class SetupModel(nn.Module):
    features: int

    def setup(self):
        self.dense = nn.Dense(self.features)  # declare here

    def __call__(self, x):
        x = self.dense(x)                      # call here
        return jax.nn.relu(x)

Both define the same computation (one Dense layer with identically shaped kernel and bias, followed by ReLU), but the parameter dict keys differ. In compact, the Dense is auto-named Dense_0. In setup, the Dense is named dense (the attribute you assigned in setup). This makes the two parameter dicts structurally different:

# @nn.compact version
params["params"]["Dense_0"]["kernel"]

# setup() version
params["params"]["dense"]["kernel"]

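The keys are not just labels. Flax folds each submodule's name into the RNG it uses to initialize that submodule's parameters, so for the same seed Dense_0 and dense receive different initial kernels, and the two models generally produce different outputs even though their architectures are identical.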

When to use setup()

Use setup() instead of @nn.compact when:

  1. Multiple methods share the same submodule, e.g. an encoder used in both encode() and __call__. Only one method per Module may be decorated with @nn.compact, and every nn.Dense(...) call inside it creates a new layer that other methods cannot reach. With setup, a submodule assigned as self.enc is available, with the same parameters, in every method.

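  2. Custom submodule names: setup names each layer after the attribute you assign it to. These names read better than auto-numbered ones like Dense_0, and they stay stable when you insert a layer earlier in __call__, which would renumber auto-named layers. (Inside @nn.compact you can get the same effect by passing name="dense" explicitly.)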

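  3. Conditional or complex init logic: setup() is ordinary Python that Flax calls lazily, the first time a bound Module's submodules or methods are used (typically once per init or apply call). It can branch on dataclass fields (hyperparameters) and build lists or dicts of submodules, e.g. self.layers = [nn.Dense(w) for w in self.widths] for a hypothetical widths field. It never receives the input x, so logic that depends on input shapes belongs in __call__.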

  4. Readability at scale — for modules with 5+ submodules, setup + call is cleaner than a long compact __call__.

Important rule: don’t mix them

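Treat setup() and @nn.compact as alternatives and choose one style per Module. Declaring some submodules in setup() and others inline in a compact method splits the Module's structure across two places and two naming schemes (attribute names vs. auto-numbered names), which makes the parameter tree harder to read and maintain.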

Worked example comparing parameter keys

import jax
import jax.numpy as jnp
import flax.linen as nn

class CompactM(nn.Module):
    features: int
    @nn.compact
    def __call__(self, x):
        # Declared inline, so Flax auto-names this submodule "Dense_0"
        return nn.Dense(self.features)(x)

class SetupM(nn.Module):
    features: int
    def setup(self):
        # The attribute name "dense" becomes the submodule's name
        self.dense = nn.Dense(self.features)
    def __call__(self, x):
        return self.dense(x)

key = jax.random.PRNGKey(0)
x = jnp.ones((4,))

# init returns the variables dict; the parameters live under "params"
c_params = CompactM(features=2).init(key, x)
s_params = SetupM(features=2).init(key, x)

print(list(c_params["params"]))  # ['Dense_0']
print(list(s_params["params"]))  # ['dense']
print(jax.tree_util.tree_map(jnp.shape, s_params))  # shapes: kernel (4, 2), bias (2,)

Note: even with the same seed and inputs, the two kernels differ, because Flax folds each submodule's name into the RNG used to initialize its parameters, so 'Dense_0' and 'dense' get different random draws. The biases still match, but only because nn.Dense initializes biases to zeros by default.

Common pitfalls

  • Forgetting to assign in setup: nn.Dense(self.features) in setup without self.X = ... does nothing — the layer is created but immediately discarded.
  • Mixing compact and setup in one Module: it splits submodule declarations across two places and two naming schemes. Pick one style per Module.
  • Expecting the same output as compact: the param names differ, so the RNG draws differ too. Don’t assume numerical equivalence.
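  • Overriding __init__ or calling super().__init__(): Flax turns every Module subclass into a dataclass (via __init_subclass__) and generates its constructor. Don't write your own __init__ or call super().__init__() in setup; declare hyperparameters as dataclass fields and build submodules in setup().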

Problem

Build a single-layer Flax Module using setup() (not @nn.compact) that wraps one nn.Dense layer, assigned as self.dense in setup(). Initialize it with the given seed and apply it to x.

Inputs:

  • seed: int (passed as float — cast to int).
  • x: 1-D JAX array.
  • features: int (cast to int).

Output: 1-D array of length features — the Dense layer output.

Hints

flax module setup
