NNX Conditional Submodule

Why this matters

Production architectures often have optional components: an extra projection head, an auxiliary classifier, a normalization layer that’s only present in “deep” variants. The clean way to express this in nnx is to branch on a Python bool inside __init__ and allocate the optional submodule only when it is needed.

The crucial rule: branching on Python bools at construction time is fine; branching on a traced JAX value at runtime in __call__ is not. A config flag like use_extra is just a Python bool, and __init__ runs in plain Python before any tracing, so a normal if is OK. Inside an nnx.jit-ed __call__, however, if jax_traced_thing: forces the tracer into a Python bool, which JAX rejects with a concretization error (TracerBoolConversionError in recent releases).
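To make the rule concrete, here is a minimal sketch using plain jax.jit (the same tracing behavior is assumed to carry over to nnx.jit). The function names bad_branch, good_select, and static_branch are made up for this illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def bad_branch(x):
    # x.sum() > 0 is a tracer here; `if` calls bool() on it, which JAX rejects.
    if x.sum() > 0:
        return x * 2
    return x


@jax.jit
def good_select(x):
    # jnp.where keeps the decision inside the traced computation.
    return jnp.where(x.sum() > 0, x * 2, x)


def static_branch(x, double: bool):
    # `double` is a plain Python bool, so a normal `if` is fine.
    return x * 2 if double else x


# Marking the flag as static means each flag value is traced separately.
static_branch_jit = jax.jit(static_branch, static_argnames="double")

x = jnp.array([1.0, -0.5, 2.0])

try:
    bad_branch(x)
    raised = False
except jax.errors.ConcretizationTypeError:
    # TracerBoolConversionError (recent JAX) is a subclass of this error.
    raised = True

print("tracer branch raised:", raised)
print(good_select(x))
print(static_branch_jit(x, double=True))
```

The static-argument variant is the jit-level analogue of storing use_extra as a plain attribute: the flag is known before tracing, so Python can branch on it.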

API: branching at construction time

import jax
from flax import nnx

class MaybeExtra(nnx.Module):
    def __init__(self, in_features: int, hidden: int, use_extra: bool, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(in_features, hidden, rngs=rngs)
        self.use_extra = use_extra            # plain Python bool (static)
        if use_extra:
            self.extra = nnx.Linear(hidden, hidden, rngs=rngs)
        self.linear2 = nnx.Linear(hidden, hidden, rngs=rngs)

    def __call__(self, x):
        x = jax.nn.relu(self.linear1(x))
        if self.use_extra:                    # plain bool, OK in Python
            x = jax.nn.relu(self.extra(x))
        return self.linear2(x)

Note: self.use_extra here is a plain bool stored as a static attribute (it doesn’t go through nnx.Param or nnx.Variable). nnx treats raw Python attributes as part of the static graph definition: they end up in graphdef, not state, so a freshly merged module gets the same flag value. We’ll see this in problem 18.

Why not branch in __call__ only?

You could check use_extra only in __call__ and not mirror it in __init__, but then the forward pass and the model’s state structure disagree: either parameters exist that are never used, or the branch reaches for a submodule that was never created. Cleaner is: allocate iff used, attribute iff allocated.

# Anti-pattern: __call__ checks the flag, but __init__ never assigned self.extra
if self.use_extra:
    x = self.extra(x)   # AttributeError

Always pair if use_extra: in __init__ with if self.use_extra: in __call__.

Cast use_extra from float to bool

The harness passes inputs as floats. Convert with a threshold:

use = bool(use_extra >= 0.5)

This is the standard pattern for boolean flags in this curriculum. (The Linen track does it the same way — see flax-conditional-submodule.)
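A minimal sketch of the cast, wrapped in a helper so the threshold lives in one place. The name to_flag is hypothetical and not part of the curriculum's harness.

```python
def to_flag(value: float) -> bool:
    """Convert a float-encoded flag to a Python bool using an explicit threshold."""
    return bool(value >= 0.5)


# Compare the threshold cast with a bare bool() on a few float inputs.
for raw in (0.0, 0.1, 0.5, 1.0):
    print(raw, "bare bool:", bool(raw), "thresholded:", to_flag(raw))
```

The two casts agree on 0.0 and 1.0 and disagree on small nonzero values such as 0.1, which a bare bool() treats as True.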

Worked behavior

Same seed, same x, different use_extra:

import jax.numpy as jnp
x = jnp.ones((3,))   # illustrative input; any length-3 vector works
model_no = MaybeExtra(in_features=3, hidden=4, use_extra=False, rngs=nnx.Rngs(0))
model_yes = MaybeExtra(in_features=3, hidden=4, use_extra=True, rngs=nnx.Rngs(0))
print(model_no(x), model_yes(x))   # different: the use_extra=True path consumes more rngs

Even with the same seed, the outputs differ. Part of the difference is simply the extra layer in the forward pass, but the initialization differs too: the use_extra=True path draws additional keys from rngs.params() for the extra Linear, so linear2’s kernel ends up initialized from a different key than in the use_extra=False case. (rngs.params() is consumed in order; any extra draw shifts everything downstream.)

Common pitfalls

  • Branching on a tracer in __call__. Inside nnx.jit, if jax_array > 0 raises a tracer-to-bool error; use jnp.where (or jax.lax.cond) instead.
  • Forgetting to set self.use_extra as an attribute. Without it, __call__ can’t read the flag.
  • Casting with a bare bool(). bool(0.1) is True (any nonzero float is truthy), while 0.1 >= 0.5 is False. Use the explicit >= 0.5 threshold so that only values at or above 0.5 enable the flag.
  • Mismatched output shape. Both branches must end with self.linear2 so the output dim is consistent.

Problem

Write conditional_forward(seed, x, hidden, use_extra):

  1. Cast use = bool(use_extra >= 0.5).
  2. Define MaybeExtra(nnx.Module) with linear1 and linear2 (both nnx.Linear(..., hidden, rngs=rngs)). If use_extra=True, also allocate self.extra = nnx.Linear(hidden, hidden, rngs=rngs). Forward: x = relu(linear1(x)); if use_extra, x = relu(extra(x)); return linear2(x).
  3. linear1 takes in_features=x.shape[-1], output hidden. linear2 is hidden -> hidden. So output length = hidden.
  4. Build with nnx.Rngs(int(seed)), instantiate, return model(x).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • hidden: int (passed as float).
  • use_extra: float — 0.0 or 1.0.

Output: 1-D array of length hidden.

