medium primitives

NNX Mutability Demo

Why this matters

nnx is eagerly mutable. Inside __call__, you can write self.some_var.value = new_value and the change PERSISTS between calls. This is the single biggest API delta from Linen, where mutable state forces you to declare collections (mutable=["batch_stats"]) and return updated state from apply as part of a tuple.

With nnx, mutable state is just attribute assignment — the same shape as PyTorch (self.bn1.running_mean.copy_(...)). The catch is that JAX’s pure-functional traced regions (jit, vmap, grad) don’t tolerate side effects. nnx solves this with split/merge (problem 12) and lifted transforms (nnx.jit, nnx.vmap — Sprint C3).

For this problem, we run the model in plain Python (no jit), so the mutation is real and immediate.

API: nnx.Variable mutation

self.count = nnx.Variable(jnp.array(0.0))   # in __init__: initialize

def __call__(self, x):
    self.count.value = self.count.value + 1.0   # increment
    return x

Things to notice:

  1. .value to read AND write. The Variable is a wrapper; the underlying array lives at .value. (In recent flax versions you can also write var[...] = new_value; the .value form remains supported.)
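  2. Reassign, don't mutate in place. JAX arrays are immutable, so the underlying buffer is never updated in place: compute a new array and assign it to .value. (self.count.value += 1.0 also works, because Python expands it to the same read, add, reassign sequence. What fails is item assignment on the array itself, e.g. self.count.value[0] = 1.0.)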
  3. No mutable=[...]. No collections, no return-tuple. Just write the attribute.
  4. The change survives between calls. If you call the model 3 times, count.value becomes 3.0.

Compare with Linen’s running-mean pattern (problem 26 of the Linen track):

# Linen
running_mean = self.variable("batch_stats", "running_mean", lambda: jnp.zeros((d,)))
if not self.is_initializing():
    running_mean.value = momentum * running_mean.value + (1 - momentum) * batch_mean
# AND in apply: mutable=["batch_stats"]; out, updated = model.apply(...)

vs. nnx:

self.running_mean.value = momentum * self.running_mean.value + (1 - momentum) * batch_mean
# That's it. No mutable=, no return-tuple.

Way less ceremony.

Worked example

import jax.numpy as jnp
from flax import nnx

class Counter(nnx.Module):
    def __init__(self, rngs):
        _ = rngs.params()                              # consume one key for shape consistency
        self.count = nnx.Variable(jnp.array(0.0))

    def __call__(self, x):
        self.count.value = self.count.value + 1.0      # persists on the module between calls
        return x

model = Counter(rngs=nnx.Rngs(0))
x = jnp.ones((4,))
for _ in range(3):
    out = model(x)
print(model.count.value)    # 3.0: the count survived
print(out)                  # [1. 1. 1. 1.]: x is unchanged

Common pitfalls

  • self.count = self.count + 1 (forgetting .value). This replaces the Variable wrapper with a raw array, breaking subsequent .value access and removing it from the state pytree.
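  • Trying to mutate inside a plain jax.jit. Raw JAX transforms expect pure functions; carrying the mutation across the trace boundary is the job of the lifted transforms such as nnx.jit (Sprint C3). For now, run eagerly.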
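  • Forgetting that the array is immutable. Item assignment such as self.count.value[0] = 1.0 raises a TypeError. Assign a whole new array to .value instead: self.count.value = self.count.value + 1.0 (the shorthand self.count.value += 1.0 is equivalent, since it also rebinds .value).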
  • Off-by-one on n_calls. A loop with range(int(n_calls)) runs exactly n_calls times. After the loop, count.value == n_calls.
  • Using nnx.Param for the counter. Then the optimizer would update it. Use nnx.Variable.

Problem

Write mutability_increment_count(seed, x, n_calls):

  1. Define Counter(nnx.Module) with one nnx.Variable count initialized to jnp.array(0.0). The __init__ should still take rngs and consume one key (just call rngs.params() and discard) — keeps the construction shape consistent across the track.
  2. __call__(self, x) increments self.count.value by 1.0 and returns x unchanged.
  3. Build nnx.Rngs(int(seed)), instantiate Counter, call it int(n_calls) times in a Python loop on x. Capture the final output (last_x).
  4. Return a length-2 array: [model.count.value, last_x.sum()].

Expected: count equals n_calls; last_x.sum() equals x.sum() (the model returns x unchanged).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • n_calls: int (passed as float).

Output: length-2 array [n_calls, x.sum()].

Hints

flax nnx mutability variable
