NNX Counters and Buffers
Why this matters
Lots of training code keeps small mutable scalars on the model:
- step counter: number of optimizer updates so far. Used by LR schedules, logging, and the gradient clip-by-norm policy (“warm up the clip threshold for the first 1000 steps”).
- Gating temperatures: for things like Mixture-of-Experts, you anneal a temperature from 5.0 -> 1.0 over training; that’s a mutable scalar.
- Debug histograms: counters for “how often did the masked head activate”, etc.
These are NOT trainable; gradients shouldn’t flow through them. They’re
NOT static metadata; they DO change during training. The right wrapper
is nnx.Variable. Save it as a Module attribute, mutate
self.step.value inside __call__, and you’re done.
This problem builds the simplest such counter — a step Variable that
increments each call — and threads it through a Python loop. With no
jit, the mutation is immediate and Python-eager.
API: counter as nnx.Variable
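```python
import jax.numpy as jnp
from flax import nnx

class StepCounter(nnx.Module):
    def __init__(self, rngs):
        _ = rngs.params()  # consume one key for shape consistency
        self.step = nnx.Variable(jnp.array(0.0))  # mutable, non-trainable state

    def __call__(self):
        # Write through .value so the Variable wrapper stays in place.
        self.step.value = self.step.value + 1.0
        return self.step.value
```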
Calling model() increments the counter and returns the new value.
The counter survives between calls because it’s a module attribute,
not a local variable.
Worked example
```python
rngs = nnx.Rngs(0)
counter = StepCounter(rngs=rngs)
for _ in range(5):
    counter()
print(counter.step.value)  # 5.0
```
Five calls; counter is 5. Linear in the number of calls.
Compare with PyTorch / Linen
PyTorch:
```python
self.register_buffer('step', torch.zeros(()))
# then later:
self.step += 1
```
register_buffer marks it as part of the module’s state (it is saved in state_dict()) but not as a
Parameter, so it is excluded from model.parameters() and the optimizer never sees it.
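A self-contained PyTorch sketch of the same counter, assuming PyTorch is installed. TorchStepCounter is a hypothetical name. It checks the two properties described above: the buffer shows up in state_dict() but not in parameters().

```python
import torch
from torch import nn


class TorchStepCounter(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        # A buffer: saved with the module's state, but not a Parameter.
        self.register_buffer("step", torch.zeros(()))

    def forward(self):
        self.step += 1  # in-place update of the buffer tensor
        return self.step


m = TorchStepCounter()
for _ in range(3):
    m()

print(len(list(m.parameters())))  # no trainable parameters
print(list(m.state_dict().keys()))
print(m.step.item())
```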
Linen:
```python
step = self.variable("counters", "step", lambda: jnp.zeros(()))
# in apply: mutable=["counters"]
step.value = step.value + 1
```
nnx is closer to PyTorch’s ergonomics: no collection name and no mutable=... argument.
Counter and JIT
Inside nnx.jit the counter would be threaded through split/merge
automatically. We don’t jit here, so the mutation is plain Python: call
the model n times in a loop and observe step go up by 1 each time.
Without nnx.jit, this would be slow at scale (every op is dispatched eagerly, with
no XLA fusion); with nnx.jit, the Variable update is lifted into the jit boundary
and runs fast. Lifting is a Sprint C3 topic.
Common pitfalls
- Saving as nnx.Param instead of nnx.Variable. Then the optimizer tries to “train” the counter. Bad.
- self.step += 1 (no .value). Replaces the Variable wrapper with a raw scalar; subsequent attribute access breaks.
- int(self.step.value) inside the loop. If read before the call, it returns the pre-increment value (0 on the first iteration), and every conversion forces a device-to-host sync. Just keep it as a JAX scalar and read it at the end.
- n_calls = 0. Should leave the counter at 0; the loop runs 0 times.
Problem
Write counter_after_n_calls(seed, n_calls):
- Define StepCounter(nnx.Module) with __init__(self, rngs) consuming one key (for shape consistency) and self.step = nnx.Variable(jnp.array(0.0)). __call__(self) increments self.step.value by 1.0 and returns the new value.
- Build nnx.Rngs(int(seed)) and instantiate StepCounter. Call it int(n_calls) times in a Python loop.
- Return jnp.array([float(model.step.value), float(int(n_calls))]).

Both should equal n_calls.
Inputs:
- seed: int (passed as float).
- n_calls: int (passed as float; can be 0).
Output: length-2 array [n_calls, n_calls].