Module That Branches on a Config Flag
Why this matters
Real architectures are configurable. A Transformer might have 6 layers in the “small” config and 24 in the “large” config. A vision model might add an extra convolutional stage when run at high resolution. The structure isn’t fixed at coding time — it’s controlled by a config field.
Flax handles this naturally: a Module is a Python class, and __call__ is a
Python method. Python if on a Module attribute (a hyperparameter) is fine
— at trace time the flag is concrete, the branch is chosen, and the resulting
forward graph is fixed. The params dict reflects whichever branch was taken.
What you must NOT do is use a Python if on the input data once the function is
traced (under jax.jit, for example, x is a Tracer). That branches on a value
JAX cannot know at trace time and raises a concretization error. (We’ll cover
that pitfall extensively in later problems.) Branching on hyperparameters is
the safe, correct pattern.
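The difference can be seen without Flax at all. The sketch below is a plain-JAX analogy, not code from this problem: make_forward and bad_forward are illustrative names, and the closure-captured Python bool stands in for a Module attribute.

```python
import jax
import jax.numpy as jnp


def make_forward(use_extra: bool):
    # use_extra is a plain Python bool captured by closure, so it is
    # concrete while JAX traces forward().
    @jax.jit
    def forward(x):
        y = x * 2.0
        if use_extra:  # resolved once, at trace time
            y = y + 1.0
        return y

    return forward


@jax.jit
def bad_forward(x):
    # Under jit, x is a Tracer, so the truth value of x[0] > 0 is unknown.
    if x[0] > 0:
        return x * 2.0
    return x


x = jnp.array([1.0, 2.0, 3.0])
out_plain = make_forward(False)(x)
out_extra = make_forward(True)(x)
print(out_plain, out_extra)

try:
    bad_forward(x)
    tracer_branch_failed = False
except jax.errors.ConcretizationTypeError:
    # Newer JAX versions raise a subclass of this error.
    tracer_branch_failed = True
print("tracer branch failed:", tracer_branch_failed)
```

Each value of use_extra produces its own traced function with a fixed graph, which is the same thing that happens when a Module attribute selects a branch.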
init reflects the chosen branch
The params dict only contains parameters for the branch that was taken during
init. If you init with use_extra=False, the extra layer’s params are not
allocated.
class ConditionalModel(nn.Module):
    hidden: int
    use_extra: bool

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden)(x)
        x = jax.nn.relu(x)
        if self.use_extra:
            x = nn.Dense(features=self.hidden)(x)
            x = jax.nn.relu(x)
        return nn.Dense(features=x.shape[-1])(x)
With use_extra=True and 1-D input of size 3 (hidden=4):
params["params"] = {
    "Dense_0": {kernel: (3, 4), bias: (4,)},
    "Dense_1": {kernel: (4, 4), bias: (4,)},  # extra
    "Dense_2": {kernel: (4, 4), bias: (4,)},  # final; x.shape[-1] is already hidden = 4 here
}
Wait — what’s x.shape[-1] at the point nn.Dense(features=x.shape[-1]) is
called? It’s the shape AFTER Dense_0 (and the optional Dense_1), so it’s
hidden. The final Dense maps hidden → hidden. (This is intentional in the
problem to keep the output shape stable; in real models you’d typically use
out_features from the config instead.)
With use_extra=False:
params["params"] = {
    "Dense_0": {kernel: (3, 4), bias: (4,)},
    "Dense_1": {kernel: (4, 4), bias: (4,)},  # final; auto-named Dense_1 because the extra layer was skipped
}
The auto-generated names Dense_0, Dense_1, … follow call order within the
forward pass that actually executed, so the same name can refer to different
layers across configs. This is one reason people prefer setup() in production
code: sub-modules take the attribute names you assign, which stay stable.
Worked example
import jax, jax.numpy as jnp
import flax.linen as nn
model_no_extra = ConditionalModel(hidden=4, use_extra=False)
model_extra = ConditionalModel(hidden=4, use_extra=True)
key = jax.random.PRNGKey(0)
x = jnp.array([1.0, 2.0, 3.0])
p_no = model_no_extra.init(key, x) # 2 Dense
p_yes = model_extra.init(key, x) # 3 Dense
print(jax.tree_util.tree_map(jnp.shape, p_no))
print(jax.tree_util.tree_map(jnp.shape, p_yes))
Common pitfalls
- Branching on a Tracer: if x[0] > 0: raises ConcretizationTypeError at trace time. For data-dependent branches, use jax.lax.cond or jnp.where. Branching on hyperparams (Module attrs) is fine.
- Forgetting the flag is bool: self.use_extra should be a Python bool, not a JAX array. If your function receives a numeric flag (0.0 / 1.0), cast: bool(flag >= 0.5).
- Re-using params across configs: don’t pass params from a use_extra=False init into a use_extra=True apply; the keys won’t match.
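For the first pitfall, the two data-dependent alternatives named above can be sketched as follows. The function names and the toy computation (double the vector when its first element is positive) are illustrative only.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale_with_where(x):
    # Computes both candidates and selects between them; no Python branch.
    return jnp.where(x[0] > 0, x * 2.0, x)


@jax.jit
def scale_with_cond(x):
    # lax.cond traces both branches and picks one at run time.
    return jax.lax.cond(x[0] > 0, lambda v: v * 2.0, lambda v: v, x)


pos = jnp.array([1.0, 2.0, 3.0])
neg = jnp.array([-1.0, 2.0, 3.0])
print(scale_with_where(pos), scale_with_where(neg))
print(scale_with_cond(pos), scale_with_cond(neg))
```

Both versions trace cleanly under jax.jit because the predicate stays inside the graph and is never converted to a Python bool.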
Problem
Build ConditionalModel(hidden, use_extra) using @nn.compact:
1. Dense(features=hidden) → ReLU
2. If use_extra is True: Dense(features=hidden) → ReLU
3. Dense(features=x.shape[-1]) (read the shape AT THIS POINT in the chain; it’s hidden after step 1 or 2)
The function takes use_extra_layer as a float (0.0 or 1.0); cast to bool
via bool(use_extra_layer >= 0.5).
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- use_extra_layer: float (0.0 or 1.0).
- hidden: int (passed as float).
Output: 1-D array, length hidden.
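Because several arguments arrive as floats, it can help to convert them in one place before building the model. The helper below is a hypothetical sketch: the name parse_inputs is not part of the problem, and the int() casts for seed and hidden are an assumption based on the "passed as float" notes above.

```python
def parse_inputs(seed: float, use_extra_layer: float, hidden: float):
    """Convert float-encoded arguments into the Python types the model needs."""
    seed_int = int(seed)                      # jax.random.PRNGKey expects an integer seed
    use_extra = bool(use_extra_layer >= 0.5)  # 0.0 -> False, 1.0 -> True
    hidden_int = int(hidden)                  # layer widths must be integers
    return seed_int, use_extra, hidden_int


print(parse_inputs(0.0, 1.0, 4.0))
print(parse_inputs(7.0, 0.0, 8.0))
```

The resulting values can then be passed to jax.random.PRNGKey and to ConditionalModel(hidden=..., use_extra=...) as ordinary Python hyperparameters.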