
NNX State Filter

Why this matters

A real model carries multiple kinds of state at once:

  • Trainable parameters (nnx.Param) updated by the optimizer.
  • Running stats (nnx.BatchStat) mutated by BatchNorm during training.
  • KV caches (nnx.Cache) mutated by attention during generation.
  • Counters / debug buffers (nnx.Variable) updated by the user.

Each kind goes through a DIFFERENT pipeline:

  • Optimizers update only Params.
  • Checkpoints save everything (or you can split by type).
  • JIT/vmap may broadcast-share Params across replicas while sharding BatchStats per-device.

The filter API — nnx.state(model, FilterType) — is how you slice the state pytree by Variable type. You’ll use it constantly in real code.

Custom filters are easy: define a subclass of nnx.Variable and pass it as the filter type. This problem defines class BatchStat(nnx.Variable): pass as a custom marker type, distinct from nnx.Param and from plain nnx.Variable instances, and checks how many leaves each filter returns.

API: custom Variable subclasses

class BatchStat(nnx.Variable):
    pass

class MyModel(nnx.Module):
    def __init__(self, ...):
        self.kernel = nnx.Param(...)         # trainable
        self.bias = nnx.Param(...)           # trainable
        self.running_mean = BatchStat(...)    # batch_stat (custom)
        self.step = nnx.Variable(...)         # plain variable

Then:

nnx.state(model, nnx.Param)         # 2 leaves: kernel, bias
nnx.state(model, BatchStat)         # 1 leaf: running_mean
nnx.state(model)                    # 4 leaves: all of them

Filtering by nnx.Param keeps only Params (and Param subclasses, if any). Filtering by BatchStat keeps only BatchStats. Plain nnx.Variable filtering matches everything because Param, BatchStat, etc. all subclass Variable.

Why subclass nnx.Variable?

  • Filterable buckets. “Update only batch_stats” or “checkpoint only params” become one-liners.
  • Sharding annotations. You can attach a default partition spec to the subclass.
  • Domain semantics. A KVCache reads as different from a RunningMean even though both are mutable variables.

Flax provides nnx.BatchStat and nnx.Cache for the most common patterns. We use a custom one here so you can see the subclass machinery is the same.

Worked example

import jax
import jax.numpy as jnp
from flax import nnx

class BatchStat(nnx.Variable):
    pass

class FilterDemo(nnx.Module):
    def __init__(self, in_features, hidden, rngs):
        key = rngs.params()
        self.kernel = nnx.Param(jax.random.normal(key, (in_features, hidden)) * 0.1)
        self.bias = nnx.Param(jnp.zeros((hidden,)))
        self.running_mean = BatchStat(jnp.zeros((hidden,)))
        self.step = nnx.Variable(jnp.array(0.0))
    def __call__(self, x):
        return x @ self.kernel + self.bias

m = FilterDemo(in_features=4, hidden=6, rngs=nnx.Rngs(0))
print(len(jax.tree_util.tree_leaves(nnx.state(m))))               # 4
print(len(jax.tree_util.tree_leaves(nnx.state(m, nnx.Param))))    # 2
print(len(jax.tree_util.tree_leaves(nnx.state(m, BatchStat))))    # 1

Common pitfalls

  • Defining BatchStat as a subclass of nnx.Param. Then it’d be filterable as a Param and the optimizer would try to train it. Subclass nnx.Variable.
  • Same name as nnx.BatchStat. Flax’s built-in is also called BatchStat. The two are different classes by identity; if you write from flax.nnx import BatchStat, you’d be using the built-in. Here we define our own to keep the example self-contained.
  • Filtering by nnx.Variable. That matches everything. Use a subclass to get a strict subset.
  • Returning integer counts. len(tree_leaves(...)) gives Python ints, but the tests want floats; wrap each count in float(...) before building the returned array.

Problem

Write state_filter_count(seed, x, hidden):

  1. Define class BatchStat(nnx.Variable): pass outside the module.
  2. Define FilterDemo(nnx.Module) with:
    • kernel: nnx.Param shape (in_features, hidden), init normal * (1/sqrt(in_features)).
    • bias: nnx.Param shape (hidden,), init zeros.
    • running_mean: BatchStat shape (hidden,), init zeros.
    • step: nnx.Variable scalar, init 0.
    • __call__: x @ kernel + bias.
  3. Build nnx.Rngs(int(seed)), instantiate the module (in_features=x.shape[-1], hidden=int(hidden)).
  4. Compute three leaf counts:
    • total_leaves: number of leaves in nnx.state(model) → 4.
    • param_leaves: number of leaves in nnx.state(model, nnx.Param) → 2.
    • bs_leaves: number of leaves in nnx.state(model, BatchStat) → 1.
  5. Return jnp.array([float(total_leaves), float(param_leaves), float(bs_leaves)]).

Expected: [4.0, 2.0, 1.0] regardless of seed/shapes.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • hidden: int (passed as float).

Output: length-3 array [4.0, 2.0, 1.0].

Hints

flax nnx filter variable-subclass
