NNX Mutability Demo
Why this matters
nnx is eagerly mutable. Inside __call__, you can write
self.some_var.value = new_value and the change PERSISTS between calls.
This is the single biggest API delta from Linen, where mutable state
forces you to declare collections (mutable=["batch_stats"]), and apply
then hands the updated state back to you alongside the output as a tuple.
With nnx, mutable state is just attribute assignment — the same shape
as PyTorch (self.bn1.running_mean.copy_(...)). The catch is that
JAX’s pure-functional traced regions (jit, vmap, grad) don’t tolerate
side effects. nnx solves this with split/merge (problem 12) and
lifted transforms (nnx.jit, nnx.vmap — Sprint C3).
For this problem, we run the model in plain Python (no jit), so
the mutation is real and immediate.
API: nnx.Variable mutation
```python
# in __init__:
self.count = nnx.Variable(jnp.array(0.0))  # initialize

def __call__(self, x):
    self.count.value = self.count.value + 1.0  # increment
    return x
```
Things to notice:
- .value to read AND write. The Variable is a wrapper; the underlying array lives at .value. (In recent flax versions you can also write var[...] = new_value; the .value form remains supported.)
- Compute a new array and reassign it. JAX arrays are immutable, so nothing is updated in place. Python's += on a JAX array just builds a new array and rebinds the name, so self.count.value += 1.0 also works, but the explicit self.count.value = self.count.value + 1.0 keeps the reassignment visible.
- No mutable=[...]. No collections, no return-tuple. Just write the attribute.
- The change survives between calls. If you call the model 3 times, count.value becomes 3.0.
Compare with Linen’s running-mean pattern (problem 26 of the Linen track):
```python
# Linen
running_mean = self.variable("batch_stats", "running_mean", lambda: jnp.zeros((d,)))
if not self.is_initializing():
    running_mean.value = momentum * running_mean.value + (1 - momentum) * batch_mean
# AND in apply: mutable=["batch_stats"]; out, updated = model.apply(...)
```

vs. nnx:

```python
self.running_mean.value = momentum * self.running_mean.value + (1 - momentum) * batch_mean
# That's it. No mutable=, no return-tuple.
```
Way less ceremony.
Worked example
```python
import jax.numpy as jnp
from flax import nnx


class Counter(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        _ = rngs.params()  # consume one key for shape consistency
        self.count = nnx.Variable(jnp.array(0.0))

    def __call__(self, x):
        self.count.value = self.count.value + 1.0
        return x


model = Counter(rngs=nnx.Rngs(0))
x = jnp.ones((4,))
for _ in range(3):
    out = model(x)
print(model.count.value)  # 3.0: the count survived
print(out)                # [1. 1. 1. 1.]: x is unchanged
```
Common pitfalls
- self.count = self.count + 1 (forgetting .value). This replaces the Variable wrapper with a raw array, breaking subsequent .value access and removing it from the state pytree.
- Trying to mutate module state inside a plain jax.jit. Side effects on the module do not survive a raw JAX trace; that is what lifted transforms such as nnx.jit are for (Sprint C3). For now, run eagerly.
- Forgetting that the array is immutable. Indexed writes such as arr[i] = v raise a TypeError on JAX arrays; build a new array instead (for example arr.at[i].set(v)) and assign it back through .value. For this scalar counter, self.count.value = self.count.value + 1.0 is all you need.
- Off-by-one on n_calls. A loop with range(int(n_calls)) runs exactly n_calls times, so after the loop count.value == n_calls.
- Using nnx.Param for the counter. nnx.Param marks trainable weights, so the optimizer would update it. Use nnx.Variable.
Problem
Write mutability_increment_count(seed, x, n_calls):
- Define Counter(nnx.Module) with one nnx.Variable count initialized to jnp.array(0.0). The __init__ should still take rngs and consume one key (just call rngs.params() and discard); this keeps the construction shape consistent across the track.
- __call__(self, x) increments self.count.value by 1.0 and returns x unchanged.
- Build nnx.Rngs(int(seed)), instantiate Counter, and call it int(n_calls) times in a Python loop on x. Capture the final output (last_x).
- Return a length-2 array: [model.count.value, last_x.sum()].

Expected: count equals n_calls; last_x.sum() equals x.sum() (the model returns x unchanged).
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- n_calls: int (passed as float).

Output: length-2 array [n_calls, x.sum()].