NNX State Filter
Why this matters
A real model carries multiple kinds of state at once:
- Trainable parameters (nnx.Param) updated by the optimizer.
- Running stats (nnx.BatchStat) mutated by BatchNorm during training.
- KV caches (nnx.Cache) mutated by attention during generation.
- Counters / debug buffers (nnx.Variable) updated by the user.
Each kind goes through a DIFFERENT pipeline:
- Optimizers update only Params.
- Checkpoints save everything (or you can split by type).
- JIT/vmap may broadcast-share Params across replicas while sharding BatchStats per-device.
The filter API — nnx.state(model, FilterType) — is how you slice the
state pytree by Variable type. You’ll use it constantly in real code.
Custom filters are easy: define a subclass of nnx.Variable and pass
it as the filter type. This problem defines class BatchStat(nnx.Variable): pass, a custom marker type distinct from plain nnx.Variable and from nnx.Param,
and checks how many leaves each filter returns.
API: custom Variable subclasses
class BatchStat(nnx.Variable):
    pass

class MyModel(nnx.Module):
    def __init__(self, ...):
        self.kernel = nnx.Param(...)        # trainable
        self.bias = nnx.Param(...)          # trainable
        self.running_mean = BatchStat(...)  # batch_stat (custom)
        self.step = nnx.Variable(...)       # plain variable
Then:
nnx.state(model, nnx.Param) # 2 leaves: kernel, bias
nnx.state(model, BatchStat) # 1 leaf: running_mean
nnx.state(model) # 4 leaves: all of them
Filtering by nnx.Param keeps only Params (and Param subclasses, if any).
Filtering by BatchStat keeps only BatchStats. Plain nnx.Variable
filtering matches everything because Param, BatchStat, etc. all
subclass Variable.
Why subclass nnx.Variable?
- Filterable buckets. “Update only batch_stats” or “checkpoint only params” become one-liners.
- Sharding annotations. You can attach a default partition spec to the subclass.
- Domain semantics. A KVCache reads as different from a RunningMean even though both are mutable variables.
Flax provides nnx.BatchStat and nnx.Cache for the most common
patterns. We use a custom one here so you can see the subclass
machinery is the same.
Worked example
import jax
import jax.numpy as jnp
from flax import nnx

class BatchStat(nnx.Variable):
    pass

class FilterDemo(nnx.Module):
    def __init__(self, in_features, hidden, rngs):
        key = rngs.params()
        self.kernel = nnx.Param(jax.random.normal(key, (in_features, hidden)) * 0.1)
        self.bias = nnx.Param(jnp.zeros((hidden,)))
        self.running_mean = BatchStat(jnp.zeros((hidden,)))
        self.step = nnx.Variable(jnp.array(0.0))

    def __call__(self, x):
        return x @ self.kernel + self.bias

m = FilterDemo(in_features=4, hidden=6, rngs=nnx.Rngs(0))
print(len(jax.tree_util.tree_leaves(nnx.state(m))))             # 4
print(len(jax.tree_util.tree_leaves(nnx.state(m, nnx.Param))))  # 2
print(len(jax.tree_util.tree_leaves(nnx.state(m, BatchStat))))  # 1
Common pitfalls
- Defining BatchStat as a subclass of nnx.Param. Then it would match the nnx.Param filter and the optimizer would try to train it. Subclass nnx.Variable instead.
- Same name as nnx.BatchStat. Flax’s built-in is also called BatchStat. They’re different classes by identity; if you write from flax.nnx import BatchStat, you’d be using the built-in. Here we define our own to keep the example self-contained.
- Filtering by nnx.Variable. That matches everything. Use a subclass to get a strict subset.
- Returning the tree_leaves counts as ints. Tests want floats; wrap each count in float(...) when building the returned array.
Problem
Write state_filter_count(seed, x, hidden):
- Define class BatchStat(nnx.Variable): pass outside the module.
- Define FilterDemo(nnx.Module) with:
  - kernel: nnx.Param, shape (in_features, hidden), init normal * (1/sqrt(in_features)).
  - bias: nnx.Param, shape (hidden,), init zeros.
  - running_mean: BatchStat, shape (hidden,), init zeros.
  - step: nnx.Variable, scalar, init 0.
  - __call__: x @ kernel + bias.
- Build nnx.Rngs(int(seed)), instantiate the module (in_features=x.shape[-1], hidden=int(hidden)).
- Compute three leaf counts:
  - total = nnx.state(model) → 4 leaves.
  - params = nnx.state(model, nnx.Param) → 2 leaves.
  - bs = nnx.state(model, BatchStat) → 1 leaf.
- Return jnp.array([float(total_leaves), float(param_leaves), float(bs_leaves)]).
Expected: [4.0, 2.0, 1.0] regardless of seed/shapes.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- hidden: int (passed as float).
Output: length-3 array [4.0, 2.0, 1.0].