medium primitives

Module That Branches on a Config Flag

Why this matters

Real architectures are configurable. A Transformer might have 6 layers in the “small” config and 24 in the “large” config. A vision model might add an extra convolutional stage when run at high resolution. The structure isn’t fixed at coding time — it’s controlled by a config field.

Flax handles this naturally: a Module is a Python class, and __call__ is a Python method. A Python if statement on a Module attribute (a hyperparameter) is fine: when the module is traced, the flag is a concrete Python value, the branch is chosen, and the resulting forward graph is fixed. The params dict reflects whichever branch was taken.

What you must NOT do is write a Python if on the input data when it is a Tracer, for example inside a jax.jit-compiled function. That branches on a value JAX can’t see at trace time and raises a concretization error. (We’ll cover that pitfall extensively in later problems.) Branching on hyperparams is the safe, correct pattern.
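A minimal sketch of the difference, assuming a jit-compiled function outside any Module. The function names bad_branch, where_branch and cond_branch are hypothetical and chosen only for illustration:

```python
import jax
import jax.numpy as jnp


@jax.jit
def bad_branch(x):
    # Under jit, x is a Tracer, so the Python `if` asks for a concrete bool
    # that does not exist at trace time.
    if x[0] > 0:
        return x * 2.0
    return x


@jax.jit
def where_branch(x):
    # jnp.where computes both candidates and selects between them.
    return jnp.where(x[0] > 0, x * 2.0, x)


@jax.jit
def cond_branch(x):
    # jax.lax.cond traces both branches and picks one at run time.
    return jax.lax.cond(x[0] > 0, lambda v: v * 2.0, lambda v: v, x)


x = jnp.array([1.0, 2.0, 3.0])

try:
    bad_branch(x)
    failed = False
except jax.errors.ConcretizationTypeError as err:
    failed = True
    print("bad_branch failed with", type(err).__name__)

y_where = where_branch(x)
y_cond = cond_branch(x)
print(y_where, y_cond)
```

Both traced alternatives return the doubled input here because x[0] is positive; the Python if never gets that far.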

init reflects the chosen branch

The params dict only contains parameters for the branch that was taken during init. If you init with use_extra=False, the extra layer’s params are not allocated.

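import jax
import flax.linen as nn

class ConditionalModel(nn.Module):
    hidden: int       # width of the hidden Dense layers
    use_extra: bool   # hyperparameter: include the second hidden block?

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden)(x)
        x = jax.nn.relu(x)
        # Plain Python `if` on a Module attribute, not on the input data.
        if self.use_extra:
            x = nn.Dense(features=self.hidden)(x)
            x = jax.nn.relu(x)
        # x.shape[-1] equals hidden here, so the output has length hidden.
        return nn.Dense(features=x.shape[-1])(x)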

With use_extra=True and 1-D input of size 3 (hidden=4):

params["params"] = {
    "Dense_0": {"kernel": (3, 4), "bias": (4,)},
    "Dense_1": {"kernel": (4, 4), "bias": (4,)},   # extra
    "Dense_2": {"kernel": (4, 4), "bias": (4,)},   # final; features = x.shape[-1], which is hidden = 4 here
}

What is x.shape[-1] at the point where nn.Dense(features=x.shape[-1]) is called? It is the last dimension AFTER Dense_0 (and the optional Dense_1), so it equals hidden, not the original input size of 3. The final Dense therefore maps hidden → hidden. (This is intentional in the problem to keep the output shape stable; in real models you’d typically use out_features from the config instead.)

With use_extra=False:

params["params"] = {
    "Dense_0": {"kernel": (3, 4), "bias": (4,)},
    "Dense_1": {"kernel": (4, 4), "bias": (4,)},   # final layer; auto-named Dense_1 because the extra layer was skipped
}

The auto-generated names Dense_0, Dense_1, … follow call order within the forward pass that actually executes, so the same name can refer to different layers across configs. This is one reason people prefer setup() (or an explicit name= argument) for production code: sub-module names stay stable.

Worked example

import jax, jax.numpy as jnp
import flax.linen as nn

model_no_extra = ConditionalModel(hidden=4, use_extra=False)
model_extra    = ConditionalModel(hidden=4, use_extra=True)

key = jax.random.PRNGKey(0)
x   = jnp.array([1.0, 2.0, 3.0])

p_no = model_no_extra.init(key, x)   # 2 Dense
p_yes = model_extra.init(key, x)     # 3 Dense

print(jax.tree_util.tree_map(jnp.shape, p_no))
print(jax.tree_util.tree_map(jnp.shape, p_yes))

Common pitfalls

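  • Branching on a Tracer: under jax.jit, if x[0] > 0: raises a ConcretizationTypeError at trace time (recent JAX versions report its subclass TracerBoolConversionError). For data-dependent branches, use jax.lax.cond or jnp.where. Branching on hyperparams (Module attrs) is fine.
  • Forgetting the flag is bool: self.use_extra should be a Python bool, not a JAX array. If your function receives a numeric flag (0.0 / 1.0), cast: bool(flag >= 0.5).
  • Re-using params across configs: don’t pass params from a use_extra=False init into a use_extra=True apply. The larger model looks for Dense_2, which the smaller params dict does not contain. The reverse direction is worse: it can run without error while using the extra layer’s weights (Dense_1) as the final layer.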

Problem

Build ConditionalModel(hidden, use_extra) using @nn.compact:

  1. Dense(features=hidden) → ReLU
  2. If use_extra is True: Dense(features=hidden) → ReLU
  3. Dense(features=x.shape[-1]) (read the shape AT THIS POINT in the chain — it’s hidden after step 1 or 2)

The function takes use_extra_layer as a float (0.0 or 1.0); cast to bool via bool(use_extra_layer >= 0.5).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • use_extra_layer: float (0.0 or 1.0).
  • hidden: int (passed as float).

Output: 1-D array, length hidden.

Hints

flax module config-branching
