NNX Variable vs Param
Why this matters
The Variable type system is what makes nnx work with optimizers, sharding,
and surgery without ad-hoc heuristics. Every piece of mutable state in a
module is wrapped in a Variable subclass, and the type tells the rest
of the framework what to do:
- nnx.Param (subclass of Variable): trainable weight; the optimizer updates it.
- nnx.Variable (the base class): generic mutable state; the optimizer ignores it; nnx.state(..., nnx.Param) skips it.
- nnx.BatchStat (subclass of Variable): running mean/var for BatchNorm.
- nnx.Cache (subclass of Variable): KV cache for attention.
- Custom subclasses, e.g. class MyTracker(nnx.Variable): pass, for any bespoke role you need to filter on later.
nnx.state(model, FilterType) walks the state pytree and keeps only the
leaves whose Variable subclass matches FilterType. The same kind of type
filter is how nnx.Optimizer (through its wrt argument) selects the
trainable subset that it passes to Optax.
This problem builds a tiny module with one Param and two Variables, then filters by type to confirm the count.
API: nnx.state with a type filter
import jax
import jax.numpy as jnp
from flax import nnx

class MixedState(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        key = rngs.params()
        # Trainable weight, scaled by 1/sqrt(in_features).
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))
        )
        # Non-trainable state: plain Variables, so a Param filter skips them.
        self.running_mean = nnx.Variable(jnp.zeros((out_features,)))
        self.step_count = nnx.Variable(jnp.array(0.0))

    def __call__(self, x):
        return x @ self.kernel

model = MixedState(in_features=4, out_features=6, rngs=nnx.Rngs(0))
all_state = nnx.state(model)               # everything: 3 leaves
param_state = nnx.state(model, nnx.Param)  # only Params: 1 leaf
nnx.state(model) returns the full state pytree. Adding a type filter
keeps only matching leaves; the nested paths to those leaves are
preserved, and non-matching entries are left out.
Subclass semantics
nnx’s filter is “is-instance-of” — nnx.Param matches nnx.Param AND any
subclass of nnx.Param. nnx.Variable matches everything (because
nnx.Param, nnx.BatchStat, etc. all subclass it). Pick the most
specific type for clarity.
To match a type while excluding one of its subclasses, combine nnx.All with nnx.Not:
    nnx.state(model, nnx.All(nnx.Variable, nnx.Not(nnx.Param)))  # Variable but not Param
We’ll use that idiom in problem 19.
Worked example
rngs = nnx.Rngs(0)
m = MixedState(in_features=3, out_features=4, rngs=rngs)
params = nnx.state(m, nnx.Param)
print(len(jax.tree_util.tree_leaves(params))) # 1 — only kernel
print(m.kernel.value.flatten()[0]) # first kernel element
Common pitfalls
- Wrapping a buffer as nnx.Param. Then your optimizer drifts the buffer like a weight. Use nnx.Variable (or a custom subclass) for non-trainables.
- Filtering with the wrong type. nnx.state(model, nnx.BatchStat) returns nothing if you used nnx.Variable for your running mean. Subclass intentionally if you need filterable buckets.
- Counting leaves vs counting Variables. An nnx.Param wrapping a single 2-D array is still ONE leaf: leaves are arrays.
- Reading .value vs unwrapping in math. Both work; .value is explicit and survives static analysis.
Problem
Write variable_vs_param(seed, x, features):
- Define MixedState(nnx.Module):
  - self.kernel = nnx.Param(...), shape (in_features, out_features), init jax.random.normal * (1/sqrt(in_features)).
  - self.running_mean = nnx.Variable(jnp.zeros((out_features,))).
  - self.step_count = nnx.Variable(jnp.array(0.0)).
  - __call__ returns x @ self.kernel (no bias, no other state used).
- Build nnx.Rngs(int(seed)), instantiate the module, call it on x to get out.
- params = nnx.state(model, nnx.Param); count its leaves (n_param_leaves = len(jax.tree_util.tree_leaves(params))). Should be 1.
- Get kernel_first = model.kernel.value.reshape(-1)[0] (the first element of the kernel as a flat array).
- Return jnp.array([float(n_param_leaves), float(kernel_first), float(out.sum())]).
Expected: [1.0, kernel[0,0], out.sum()].
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: length-3 array.