
NNX Multi-State Types

Why this matters

A serious model carries SEVERAL kinds of state at once:

  • nnx.Param — trainable kernel and bias (Optax-managed).
  • BatchStat (custom subclass of nnx.Variable) — running mean/var for BatchNorm.
  • Plain nnx.Variable — counters, gating temperatures, debug stats.
  • nnx.Cache — KV cache entries for fast attention generation.

Each kind needs separate handling. Optax should update only the Params. A checkpointing or sharding setup might shard the cache and batch stats per device while replicating the Params. Restarting generation should zero only the cache.

nnx supports this via nnx.state(model, *filters), where a filter can be a Variable type or a composition of filters, giving fine-grained selection. This problem stitches them all together into one bookkeeping example: a model with FOUR distinct kinds of state and a clean way to count each.

Compare with Linen’s “variable collections”: there you register each namespace separately by string name (params, batch_stats, cache) and pass the variables around as a dict of dicts. nnx replaces the string names with a type system, where each Variable’s class is its collection. This is more flexible and less error-prone.

API: composable filters

To select “plain Variables that are not a Param, BatchStat, or Cache”, you compose:

nnx.All(nnx.Variable, nnx.Not(nnx.Param), nnx.Not(BatchStat), nnx.Not(nnx.Cache))

nnx.All(A, B, C) matches a leaf only if it passes ALL of the sub-filters. nnx.Not(X) matches every leaf that X does not match. A type filter X also matches subclasses of X, so Not(X) excludes those subclasses too.

Other handy combinators:

  • nnx.Any(A, B) — match either.
  • nnx.Not(A) — match everything except A.
  • nnx.Everything() / nnx.Nothing() — sentinels for “all” / “none”.

The most common pattern in real code is plain type-by-type filtering (nnx.Param, nnx.BatchStat, nnx.Cache), relying on the subclass hierarchy. The All/Not composition is for the less common case of “this base type, minus specific subclasses.”

Worked example

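import jax
import jax.numpy as jnp
from flax import nnx

# Custom Variable subclass (flax.nnx also ships a built-in nnx.BatchStat).
class BatchStat(nnx.Variable):
    pass

class MultiState(nnx.Module):
    def __init__(self, in_features, hidden, rngs):
        key = rngs.params()
        # Trainable parameters.
        self.kernel = nnx.Param(jax.random.normal(key, (in_features, hidden)) * 0.1)
        self.bias = nnx.Param(jnp.zeros((hidden,)))
        # Running statistics, not touched by the optimizer.
        self.running_mean = BatchStat(jnp.zeros((hidden,)))
        self.running_var = BatchStat(jnp.ones((hidden,)))
        # Plain Variable: a bookkeeping counter.
        self.step_count = nnx.Variable(jnp.array(0.0))
        # Cache slot, e.g. for KV entries during generation.
        self.cache = nnx.Cache(jnp.zeros((4, hidden)))

    def __call__(self, x):
        return x @ self.kernel + self.bias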

m = MultiState(3, 5, rngs=nnx.Rngs(0))

# Total state: 6 leaves (kernel, bias, mean, var, step, cache).
print(len(jax.tree_util.tree_leaves(nnx.state(m))))             # 6

# Filter by Param: 2 leaves (kernel, bias).
print(len(jax.tree_util.tree_leaves(nnx.state(m, nnx.Param))))  # 2

# Filter by BatchStat: 2 leaves (running_mean, running_var).
print(len(jax.tree_util.tree_leaves(nnx.state(m, BatchStat))))  # 2

# Plain Variable (no subclasses): 1 leaf (step_count).
plain = nnx.All(nnx.Variable, nnx.Not(nnx.Param), nnx.Not(BatchStat), nnx.Not(nnx.Cache))
print(len(jax.tree_util.tree_leaves(nnx.state(m, plain))))      # 1

Notice that step_count is wrapped as nnx.Variable, the base class. Filtering by nnx.Variable would match all six leaves, because Param, BatchStat, and Cache all subclass Variable. To select only the “plain” Variables, exclude every subclass the model uses. Any subclass you forget to exclude leaks into the result.

Common pitfalls

  • nnx.state(model, nnx.Variable) returning everything. That’s working as designed (Variable is the root of the hierarchy). For “plain Variable only,” use the All(... Not ...) pattern.
  • Forgetting to Not(nnx.Cache) when pulling plain Variables. Cache is a Variable subclass.
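  • Counting attributes or top-level keys instead of leaves. A State is a nested mapping that mirrors the module tree, so len(state) counts only the top-level entries (a whole submodule counts as one). jax.tree_util.tree_leaves(state) flattens it to one leaf per array.
  • Initializing running_var to zeros. At inference, BatchNorm divides by sqrt(running_var + eps), so a zero variance scales activations by 1/sqrt(eps), and divides by zero outright if eps is 0. Convention: the variance starts at 1 and the mean at 0.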

Problem

Write multi_state_filter(seed, x, hidden, eps):

  1. Define class BatchStat(nnx.Variable): pass at module level.
  2. Define MultiState(nnx.Module) with six state attributes and a __call__ method:
    • self.kernel = nnx.Param(...) with shape (in_features, hidden), initialized from a standard normal scaled by 1/sqrt(in_features).
    • self.bias = nnx.Param(jnp.zeros((hidden,))).
    • self.running_mean = BatchStat(jnp.zeros((hidden,))).
    • self.running_var = BatchStat(jnp.ones((hidden,))).
    • self.step_count = nnx.Variable(jnp.array(0.0)).
    • self.cache = nnx.Cache(jnp.zeros((4, hidden))).
    • __call__: x @ self.kernel + self.bias.
  3. Build nnx.Rngs(int(seed)), instantiate the module (in_features=x.shape[-1], hidden=int(hidden)).
  4. Count the leaves of three filtered states with len(jax.tree_util.tree_leaves(...)):
    • param_leaves from nnx.state(model, nnx.Param) → 2.
    • bs_leaves from nnx.state(model, BatchStat) → 2.
    • plain_leaves from nnx.state(model, nnx.All(nnx.Variable, nnx.Not(nnx.Param), nnx.Not(BatchStat), nnx.Not(nnx.Cache))) → 1.
  5. Return jnp.array([float(param_leaves), float(bs_leaves), float(plain_leaves)]). (eps is unused; add 0.0 * float(eps) to the result to keep it referenced.)

Expected: [2.0, 2.0, 1.0] regardless of seed/shapes/eps.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • hidden: int (passed as float).
  • eps: float (unused).

Output: length-3 array [2.0, 2.0, 1.0].

Hints

flax nnx filter multi-collection
