NNX Multi-Param Module
Why this matters
Real modules carry MORE than just weights. They carry:
- Trainable parameters: nnx.Param. Receive gradients; updated by the optimizer. Examples: kernel matrices, biases, embedding tables.
- Non-trainable variables: nnx.Variable. Don’t receive gradients, but DO live in the module’s state pytree (so they’re saved and restored with the model). Examples: counters, running averages, KV caches.
Both are stored as module attributes; the wrapper class tags the role. Without this distinction, an optimizer would either skip a tensor that should be updated (bug) or try to update one that shouldn’t be (worse).
This problem builds a small Linear with kernel + bias + a non-trainable counter — your first exposure to the param/variable split that drives the rest of nnx.
API: nnx.Param vs nnx.Variable
self.kernel = nnx.Param(jax.random.normal(key, (in_d, out_d))) # trainable
self.bias = nnx.Param(jnp.zeros((out_d,))) # trainable
self.count = nnx.Variable(jnp.array(0.0)) # not trainable
Both are wrappers around JAX arrays. The difference is the type — the state-filtering API uses it to separate trainables from non-trainables:
nnx.state(model, nnx.Param) # only kernel, bias
nnx.state(model, nnx.Variable) # all of them (Param subclasses Variable)
Notice the inheritance: nnx.Param is a subclass of nnx.Variable. So
“filter by Variable” returns everything; “filter by Param” returns only
the trainable subset. This subclass relationship is what lets you build
custom variable types (BatchStat, Cache, …) that the type system
treats distinctly. We’ll use that in problems 13 and 19.
Worked example
import jax
import jax.numpy as jnp
from flax import nnx

class MultiParam(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        key = rngs.params()
        # Normal init scaled by 1/sqrt(in_features)
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))
        )
        self.bias = nnx.Param(jnp.zeros((out_features,)))
        self.count = nnx.Variable(jnp.array(0.0))  # NOT a Param, so the optimizer won't update it

    def __call__(self, x):
        return x @ self.kernel + self.bias

model = MultiParam(4, 2, rngs=nnx.Rngs(0))
params = nnx.state(model, nnx.Param)  # {kernel, bias}
all_state = nnx.state(model)          # {kernel, bias, count}
Common pitfalls
- Wrapping a counter as nnx.Param. Then the optimizer will update it as if it were a weight. Use nnx.Variable for non-trainables.
- Forgetting to wrap. Plain self.count = jnp.array(0.0) (no wrapper) makes count a static attribute, not part of the module’s state pytree. It will be ignored by nnx.split and nnx.state, and won’t be saved/restored. Sometimes that’s what you want (problem 18), but usually not.
- Mutating count inside __call__ without an nnx.Variable. Plain attribute assignment of a JAX array won’t survive split/merge round-trips. Wrap as Variable.
- bias as Variable instead of Param. Bias IS trainable, so wrap it as nnx.Param.
This problem (no mutation yet)
Define a Linear-like module with kernel (Param), bias (Param), and a
count Variable initialized to 0. The forward pass just computes
x @ kernel + bias — DO NOT increment count yet (we’ll do that in
problem 5). The variable is just sitting there as state that COULD be
mutated.
Sanity check: with all biases zero, the output is identical to problem 1’s bias-less Linear with the same seed.
Problem
Write multi_param_forward(seed, x, features):
- Define MultiParam(nnx.Module) with three attributes: self.kernel = nnx.Param(...), self.bias = nnx.Param(...), self.count = nnx.Variable(jnp.array(0.0)).
- Forward: x @ self.kernel + self.bias. Don’t touch count.
- Build nnx.Rngs(int(seed)), instantiate the module (in_features=x.shape[-1], out_features=int(features)), and return the forward output.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: 1-D array of length features.