NNX Multi State Types
Why this matters
A serious model carries SEVERAL kinds of state at once:
- nnx.Param: trainable kernel and bias (Optax-managed).
- BatchStat (custom subclass of nnx.Variable): running mean/var for BatchNorm.
- Plain nnx.Variable: counters, gating temperatures, debug stats.
- nnx.Cache: KV cache entries for fast attention generation.
Each piece needs separate handling: optax updates only Params; the serializer might shard cache + batch stats per-device while broadcasting Params; the train step zeros only the cache when restarting generation.
nnx supports this via the nnx.state(model, FilterType) API, with
composable filters for fine-grained selection. This problem stitches
them all together into one bookkeeping example — a model with FOUR
distinct kinds of state and a clean way to count each.
Compare with Linen’s “variable collections”: you’d register each
namespace separately (params, batch_stats, cache) and pass the
tree as a dict of dicts. nnx replaces it with a type system, which
is more flexible and less error-prone.
API: composable filters
For “exactly Variable, no subclasses” you compose:
nnx.All(nnx.Variable, nnx.Not(nnx.Param), nnx.Not(BatchStat), nnx.Not(nnx.Cache))
nnx.All(A, B, C) matches leaves whose Variable subclass passes ALL the
sub-filters. nnx.Not(X) excludes leaves of type X.
Other handy combinators:
- nnx.Any(A, B): match either.
- nnx.Not(A): match everything except A.
- nnx.Everything() / nnx.Nothing(): sentinels for “all” / “none”.
The most common pattern in real code is just type-by-type (nnx.Param,
nnx.BatchStat, nnx.Cache), relying on the subclass hierarchy. The
All/Not composition is for the unusual case of “exactly this type,
no subclasses.”
Worked example
import jax
import jax.numpy as jnp
from flax import nnx

class BatchStat(nnx.Variable):
    pass

class MultiState(nnx.Module):
    def __init__(self, in_features, hidden, rngs):
        key = rngs.params()
        self.kernel = nnx.Param(jax.random.normal(key, (in_features, hidden)) * 0.1)
        self.bias = nnx.Param(jnp.zeros((hidden,)))
        self.running_mean = BatchStat(jnp.zeros((hidden,)))
        self.running_var = BatchStat(jnp.ones((hidden,)))
        self.step_count = nnx.Variable(jnp.array(0.0))
        self.cache = nnx.Cache(jnp.zeros((4, hidden)))

    def __call__(self, x):
        return x @ self.kernel + self.bias

m = MultiState(3, 5, rngs=nnx.Rngs(0))

# Total state: 6 leaves (kernel, bias, mean, var, step, cache).
print(len(jax.tree_util.tree_leaves(nnx.state(m))))  # 6

# Filter by Param: 2 leaves (kernel, bias).
print(len(jax.tree_util.tree_leaves(nnx.state(m, nnx.Param))))  # 2

# Filter by BatchStat: 2 leaves (running_mean, running_var).
print(len(jax.tree_util.tree_leaves(nnx.state(m, BatchStat))))  # 2

# Plain Variable (no subclasses): 1 leaf (step_count).
plain = nnx.All(nnx.Variable, nnx.Not(nnx.Param), nnx.Not(BatchStat), nnx.Not(nnx.Cache))
print(len(jax.tree_util.tree_leaves(nnx.state(m, plain))))  # 1
Notice that step_count is wrapped as nnx.Variable (the base class).
Filtering by nnx.Variable would match ALL six leaves, because Param,
BatchStat, and Cache all subclass Variable. To select only “exact”
Variables, exclude every subclass you’ve used.
Common pitfalls
- nnx.state(model, nnx.Variable) returning everything. That’s working as designed (Variable is the root of the hierarchy). For “plain Variable only,” use the All(... Not ...) pattern.
- Forgetting to Not(nnx.Cache) when pulling plain Variables. Cache is a Variable subclass.
- Counting model.attr instead of tree_leaves(state). State pytrees have nested structure; only tree_leaves flattens.
- Initializing running_var to zeros. That’d be numerically wrong (division by zero). Convention: var starts at 1.
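The last pitfall can be illustrated with plain JAX, assuming the running statistics feed the usual BatchNorm-style formula (x - mean) / sqrt(var + eps). The normalize helper is hypothetical and is not part of the MultiState module, and eps is set to 0 in the second call to expose the raw division.

```python
import jax.numpy as jnp


def normalize(x, mean, var, eps):
    # BatchNorm-style normalization with running statistics (assumed formula).
    return (x - mean) / jnp.sqrt(var + eps)


x = jnp.array([1.0, 2.0, 3.0])
mean = jnp.zeros((3,))

# var initialized to ones: normalization starts out close to the identity.
ok = normalize(x, mean, jnp.ones((3,)), eps=1e-5)

# var initialized to zeros with eps=0: division by zero, non-finite output.
bad = normalize(x, mean, jnp.zeros((3,)), eps=0.0)

ok_is_finite = bool(jnp.all(jnp.isfinite(ok)))
bad_is_finite = bool(jnp.any(jnp.isfinite(bad)))
```

With a small positive eps the zero-initialized case stays finite but scales activations by roughly 1/sqrt(eps), which is still far from a sensible starting point.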
Problem
Write multi_state_filter(seed, x, hidden, eps):
- Define class BatchStat(nnx.Variable): pass at module level.
- Define MultiState(nnx.Module) with six attributes and a __call__:
  - self.kernel = nnx.Param(...), shape (in_features, hidden), init normal * (1/sqrt(in_features)).
  - self.bias = nnx.Param(jnp.zeros((hidden,))).
  - self.running_mean = BatchStat(jnp.zeros((hidden,))).
  - self.running_var = BatchStat(jnp.ones((hidden,))).
  - self.step_count = nnx.Variable(jnp.array(0.0)).
  - self.cache = nnx.Cache(jnp.zeros((4, hidden))).
  - __call__: x @ self.kernel + self.bias.
- Build nnx.Rngs(int(seed)), instantiate the module (in_features=x.shape[-1], hidden=int(hidden)).
- Compute three filtered states and count the tree_leaves of each:
  - params = nnx.state(model, nnx.Param) → 2 leaves.
  - batch_stats = nnx.state(model, BatchStat) → 2 leaves.
  - plain_vars = nnx.state(model, nnx.All(nnx.Variable, nnx.Not(nnx.Param), nnx.Not(BatchStat), nnx.Not(nnx.Cache))) → 1 leaf.
- Return jnp.array([float(param_leaves), float(bs_leaves), float(plain_leaves)]), where param_leaves, bs_leaves, and plain_leaves are the leaf counts of params, batch_stats, and plain_vars. (eps is unused; add 0.0 * float(eps) to the result to keep it referenced.)
Expected: [2.0, 2.0, 1.0] regardless of seed/shapes/eps.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- hidden: int (passed as float).
- eps: float (unused).
Output: length-3 array [2.0, 2.0, 1.0].