NNX Variable vs Param

Why this matters

The Variable type system is what makes nnx work with optimizers, sharding, and surgery without ad-hoc heuristics. Every piece of mutable state in a module is wrapped in a Variable subclass, and the type tells the rest of the framework what to do:

  • nnx.Param (subclass of Variable) — trainable weight; the optimizer updates it.
  • nnx.Variable (the base class) — generic mutable state; the optimizer ignores it; nnx.state(..., nnx.Param) skips it.
  • nnx.BatchStat (subclass of Variable) — running mean/var for BatchNorm.
  • nnx.Cache (subclass of Variable) — KV cache for attention.
  • Custom subclasses — class MyTracker(nnx.Variable): pass — for any bespoke role you need to filter on later.

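nnx.state(model, FilterType) walks the module's state and keeps only the leaves whose Variable type matches FilterType. Optax itself knows nothing about Variable types; it is nnx.Optimizer that applies the same kind of filter (its wrt argument, typically nnx.Param) to select the trainable subset it hands to Optax.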

This problem builds a tiny module with one Param and two Variables, then filters by type to confirm the count.

API: nnx.state with a type filter

import jax
import jax.numpy as jnp
from flax import nnx

class MixedState(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        key = rngs.params()
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))
        )
        self.running_mean = nnx.Variable(jnp.zeros((out_features,)))
        self.step_count = nnx.Variable(jnp.array(0.0))

    def __call__(self, x):
        return x @ self.kernel

model = MixedState(in_features=4, out_features=6, rngs=nnx.Rngs(0))

all_state = nnx.state(model)              # everything: 3 leaves
param_state = nnx.state(model, nnx.Param)  # only Params: 1 leaf

nnx.state(model) returns the full state, keyed by attribute path. Adding a type filter keeps the path structure for the leaves that match and drops everything else, so param_state contains only the kernel entry.

Subclass semantics

nnx’s filter is “is-instance-of” — nnx.Param matches nnx.Param AND any subclass of nnx.Param. nnx.Variable matches everything (because nnx.Param, nnx.BatchStat, etc. all subclass it). Pick the most specific type for clarity.

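Filters can be combined. To exclude one branch of the type hierarchy, intersect with nnx.All and negate with nnx.Not. Note that this is an exclusion, not an exact-type match: BatchStat, Cache, and custom subclasses still pass, because they are Variables but not Params.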

nnx.state(model, nnx.All(nnx.Variable, nnx.Not(nnx.Param)))   # Variable but not Param

We’ll use that idiom in problem 19.

Worked example

rngs = nnx.Rngs(0)
m = MixedState(in_features=3, out_features=4, rngs=rngs)
params = nnx.state(m, nnx.Param)
print(len(jax.tree_util.tree_leaves(params)))    # 1 — only kernel
print(m.kernel.value.flatten()[0])                # first kernel element

Common pitfalls

  • Wrapping a buffer as nnx.Param. The optimizer will then update the buffer as if it were a weight. Use nnx.Variable (or a custom subclass) for non-trainable state.
  • Filtering with the wrong type. nnx.state(model, nnx.BatchStat) returns an empty state if you declared your running mean as a plain nnx.Variable. Subclass intentionally if you need filterable buckets.
  • Counting leaves vs counting elements. An nnx.Param wrapping a single 2-D array is still ONE leaf: leaves are whole arrays, not their elements.
  • Reading .value vs using the Variable directly in math. Both work: x @ self.kernel and x @ self.kernel.value give the same result. .value makes the unwrapping explicit, which is easier for readers and static analysis tools to follow.

Problem

Write variable_vs_param(seed, x, features):

  1. Define MixedState(nnx.Module):
    • self.kernel = nnx.Param(...) shape (in_features, out_features), init jax.random.normal * (1/sqrt(in_features)).
    • self.running_mean = nnx.Variable(jnp.zeros((out_features,))).
    • self.step_count = nnx.Variable(jnp.array(0.0)).
    • __call__ returns x @ self.kernel (no bias, no other state used).
  2. Build nnx.Rngs(int(seed)), instantiate the module with in_features = x.shape[-1] and out_features = int(features), and call it on x to get out.
  3. params = nnx.state(model, nnx.Param); count its leaves (n_param_leaves = len(jax.tree_util.tree_leaves(params))). Should be 1.
  4. Get kernel_first = model.kernel.value.reshape(-1)[0] (the first element of the kernel as a flat array).
  5. Return jnp.array([float(n_param_leaves), float(kernel_first), float(out.sum())]).

Expected: [1.0, kernel[0,0], out.sum()].

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).

Output: length-3 array.

Hints

flax nnx param variable filter
