
NNX Multi-Param Module

Why this matters

Real modules carry MORE than just weights. They carry:

  • Trainable parameters: nnx.Param. Receive gradients; updated by the optimizer. Examples: kernel matrices, biases, embedding tables.
  • Non-trainable variables: nnx.Variable. Don’t receive gradients, but DO live in the module’s state pytree (so they’re saved/restored with the model). Examples: counters, running averages, KV caches.

Both are stored as module attributes; the wrapper class tags the role. Without this distinction, an optimizer would either skip a tensor that should be updated (bug) or try to update one that shouldn’t be (worse).

This problem builds a small Linear with kernel + bias + a non-trainable counter — your first exposure to the param/variable split that drives the rest of nnx.

API: nnx.Param vs nnx.Variable

self.kernel = nnx.Param(jax.random.normal(key, (in_d, out_d)))   # trainable
self.bias = nnx.Param(jnp.zeros((out_d,)))                        # trainable
self.count = nnx.Variable(jnp.array(0.0))                         # not trainable

Both are wrappers around JAX arrays. The difference is the type — the state-filtering API uses it to separate trainables from non-trainables:

nnx.state(model, nnx.Param)       # only kernel, bias
nnx.state(model, nnx.Variable)    # all of them (Param subclasses Variable)

Notice the inheritance: nnx.Param is a subclass of nnx.Variable. So “filter by Variable” returns everything; “filter by Param” returns only the trainable subset. This subclass relationship is what lets you build custom variable types (BatchStat, Cache, …) that a filter can select separately from the Params. We’ll use that in problems 13 and 19.

Worked example

import jax
import jax.numpy as jnp
from flax import nnx

class MultiParam(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        key = rngs.params()
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))
        )
        self.bias = nnx.Param(jnp.zeros((out_features,)))
        self.count = nnx.Variable(jnp.array(0.0))   # NOT a Param — won't be updated by optimizer

    def __call__(self, x):
        return x @ self.kernel + self.bias

model = MultiParam(4, 2, rngs=nnx.Rngs(0))
params = nnx.state(model, nnx.Param)              # {kernel, bias}
all_state = nnx.state(model)                       # {kernel, bias, count}

Common pitfalls

  • Wrapping a counter as nnx.Param. Then the optimizer will update it as if it were a weight. Use nnx.Variable for non-trainables.
  • Forgetting to wrap. Plain self.count = jnp.array(0.0) (no wrapper) makes count a static attribute that is not part of the module’s state pytree. It will be ignored by nnx.split and nnx.state, and it won’t be saved or restored. Sometimes that’s what you want (problem 18), but usually not.
  • Mutating count inside __call__ without an nnx.Variable. Plain attribute assignment of a jax array won’t survive split/merge round-trips. Wrap as Variable.
  • bias as Variable instead of Param. Bias IS trainable — wrap as nnx.Param.

This problem (no mutation yet)

Define a Linear-like module with kernel (Param), bias (Param), and a count Variable initialized to 0. The forward pass just computes x @ kernel + bias — DO NOT increment count yet (we’ll do that in problem 5). The variable is just sitting there as state that COULD be mutated.

Sanity check: with all biases zero, the output is identical to problem 1’s bias-less Linear with the same seed.

Problem

Write multi_param_forward(seed, x, features):

  1. Define MultiParam(nnx.Module) with three attributes: self.kernel = nnx.Param(...), self.bias = nnx.Param(...), self.count = nnx.Variable(jnp.array(0.0)).
  2. Forward: x @ self.kernel + self.bias. Don’t touch count.
  3. Build nnx.Rngs(int(seed)), instantiate the module (in_features=x.shape[-1], out_features=int(features)), return the forward output.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).

Output: 1-D array of length features.

Hints

flax nnx param variable
