NNX Conditional Submodule
Why this matters
Production architectures often have OPTIONAL components: an extra projection
head, an auxiliary classifier, a normalization layer that's only present in
"deep" variants. The clean way to express this in nnx is to branch on a
Python bool inside __init__ and allocate the optional submodule only when
it's needed.
The crucial rule: branching on a Python bool at construction time is fine;
branching on a traced JAX value at runtime in __call__ is not. A
config flag like use_extra is just a Python bool, and __init__ runs in
plain Python before any tracing, so a normal if is OK. Inside an
nnx.jit-ed __call__, however, JAX arrays are abstract tracers with no
concrete value, so if jax_traced_thing: raises an error instead of picking
a branch. Branching on self.use_extra in __call__ stays fine, because the
flag is still a plain Python bool there.
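A minimal sketch of the difference, using plain jax.jit (nnx.jit wraps jax.jit, so array arguments are traced the same way). The function names branch_with_if and branch_with_where are made up for illustration; the trace-safe version swaps the Python if for jnp.where.

```python
import jax
import jax.numpy as jnp

@jax.jit
def branch_with_if(x):
    # Under jit, x is an abstract tracer: `x.sum() > 0` has no concrete value,
    # so Python's `if` cannot turn it into a bool and JAX raises an error.
    if x.sum() > 0:
        return x * 2.0
    return x

@jax.jit
def branch_with_where(x):
    # jnp.where computes both candidates and selects at runtime: trace-safe.
    return jnp.where(x.sum() > 0, x * 2.0, x)

x = jnp.array([1.0, -0.5, 2.0])

try:
    branch_with_if(x)
    if_raised = False
except jax.errors.ConcretizationTypeError:
    # Recent JAX versions raise TracerBoolConversionError, a subclass of this.
    if_raised = True

y = branch_with_where(x)  # x.sum() == 2.5 > 0, so y == x * 2.0
print(if_raised, y)
```

Note that jnp.where evaluates both branches, so it only works when both sides are computable; a config flag that decides whether a submodule exists at all belongs in __init__ instead.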
API: branching at construction time
import jax
from flax import nnx

class MaybeExtra(nnx.Module):
    def __init__(self, in_features, hidden, use_extra: bool, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(in_features, hidden, rngs=rngs)
        self.use_extra = use_extra  # plain Python bool (static)
        if use_extra:
            self.extra = nnx.Linear(hidden, hidden, rngs=rngs)
        self.linear2 = nnx.Linear(hidden, hidden, rngs=rngs)

    def __call__(self, x):
        x = jax.nn.relu(self.linear1(x))
        if self.use_extra:  # plain bool, OK in Python
            x = jax.nn.relu(self.extra(x))
        return self.linear2(x)
Note: self.use_extra here is a plain bool stored as a static attribute
(it doesn't go through nnx.Param or nnx.Variable). nnx treats raw Python
attributes as static: they end up in the graphdef returned by nnx.split,
not in the state, so a module rebuilt with nnx.merge gets the same flag
value. We'll see this in problem 18.
Why not branch in __call__ only?
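You could branch on use_extra only in __call__, but then the forward pass
and the model's state can disagree: if extra is never allocated, the
flag-on path calls a submodule that doesn't exist; if it is always
allocated, the flag-off model carries parameters it never uses. Cleaner
is: allocate iff used, assign the attribute iff allocated.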
# Anti-pattern (inside __call__)
if self.use_extra:
    # self.extra was never assigned in __init__, so this raises AttributeError
    x = self.extra(x)
Always pair if use_extra: in __init__ with if self.use_extra: in
__call__.
Cast use_extra from float to bool
The harness passes inputs as floats. Convert with a threshold:
use = bool(use_extra >= 0.5)
This is the standard pattern for boolean flags in this curriculum. (The
Linen track does it the same way — see flax-conditional-submodule.)
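The cast can be wrapped in a tiny helper; to_flag is a hypothetical name, not part of the harness. It also accepts a concrete JAX scalar, because it runs eagerly outside any jit.

```python
import jax.numpy as jnp

def to_flag(use_extra) -> bool:
    """Hypothetical helper: turn a harness float (0.0 or 1.0) into a Python bool."""
    # `>= 0.5` compares against the midpoint, so 0.0 -> False and 1.0 -> True.
    # Wrapping in bool() yields a plain Python bool, safe to branch on in __init__.
    return bool(use_extra >= 0.5)

flags = [to_flag(0.0), to_flag(1.0), to_flag(jnp.asarray(1.0))]
print(flags)
```

Inside a jitted function, the same bool(...) on a traced value would run into the tracer problem described under the crucial rule, which is why the cast belongs before the model is built.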
Worked behavior
Same seed, same x, different use_extra:
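import jax.numpy as jnp
from flax import nnx
x = jnp.array([1.0, 2.0, 3.0])  # length 3 = in_features; MaybeExtra as defined above
model_no = MaybeExtra(in_features=3, hidden=4, use_extra=False, rngs=nnx.Rngs(0))
model_yes = MaybeExtra(in_features=3, hidden=4, use_extra=True, rngs=nnx.Rngs(0))
print(model_no(x), model_yes(x))  # different: the True path adds a layer and consumes more rngs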
Even with the same seed, the outputs differ. The use_extra=True path applies
one more Linear + relu, and building that extra Linear draws additional keys
from rngs.params(), so linear2's kernel ends up initialized from a different
key than in the use_extra=False case. (rngs.params() hands out keys in call
order; extra draws shift everything built afterwards. linear1 is built first
in both models, so its parameters match.)
Common pitfalls
- Branching on a tracer in __call__. Inside nnx.jit, if jax_array > 0: is wrong; use jnp.where (or jax.lax.cond).
- Forgetting to set self.use_extra as an attribute. Without it, __call__ can't read the flag.
- Casting without a threshold. bool(use_extra) is plain truthiness: bool(0.5) is True, and so is any other nonzero float. Compare explicitly with use_extra >= 0.5, the threshold used throughout.
- Mismatched output shape. Both branches must end with self.linear2 so the output dim is consistent.
Problem
Write conditional_forward(seed, x, hidden, use_extra):
- Cast use = bool(use_extra >= 0.5).
- Define MaybeExtra(nnx.Module) with linear1 and linear2 (both nnx.Linear(..., hidden, rngs=rngs)). If use_extra=True, also allocate self.extra = nnx.Linear(hidden, hidden, rngs=rngs). Forward: x = relu(linear1(x)); if use_extra, x = relu(extra(x)); return linear2(x).
- linear1 takes in_features=x.shape[-1] and outputs hidden. linear2 is hidden -> hidden. So the output length is hidden.
- Build with nnx.Rngs(int(seed)), instantiate, and return model(x).
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- hidden: int (passed as float).
- use_extra: float, 0.0 or 1.0.
Output: 1-D array of length hidden.