
NNX Counters and Buffers

Why this matters

Lots of training code keeps small mutable scalars on the model:

  • Step counter: the number of optimizer updates so far. Used by LR schedules, logging, and gradient clip-by-norm policies (“warm up the clip threshold for the first 1000 steps”).
  • Gating temperatures: in Mixture-of-Experts models, for example, a temperature is annealed from 5.0 to 1.0 over training; that’s a mutable scalar.
  • Debug histograms: counters for “how often did the masked head activate”, etc.

These are NOT trainable, so gradients shouldn’t flow through them, and they are NOT static metadata, because they DO change during training. The right wrapper is nnx.Variable: store it as a Module attribute, mutate self.step.value inside __call__, and you’re done.

This problem builds the simplest such counter — a step Variable that increments each call — and threads it through a Python loop. With no jit, the mutation is immediate and Python-eager.

API: counter as nnx.Variable

import jax.numpy as jnp
from flax import nnx

class StepCounter(nnx.Module):
    def __init__(self, rngs):
        _ = rngs.params()                          # consume one key for shape consistency
        self.step = nnx.Variable(jnp.array(0.0))   # non-trainable, mutable state

    def __call__(self):
        self.step.value = self.step.value + 1.0    # update through .value, not by rebinding
        return self.step.value

Calling an instance of StepCounter increments the counter and returns the new value. The count persists between calls because it is stored on the module as an attribute rather than in a local variable.

Worked example

rngs = nnx.Rngs(0)
counter = StepCounter(rngs=rngs)
for _ in range(5):
    counter()
print(counter.step.value)    # 5.0

Five calls leave the counter at 5.0; in general, n calls leave it at n.

Compare with PyTorch / Linen

PyTorch:

self.register_buffer('step', torch.zeros(()))
# then later:
self.step += 1

register_buffer marks it as part of the module’s state but not a Parameter (so the optimizer skips it).

Linen:

step = self.variable("counters", "step", lambda: jnp.zeros(()))
# in apply: mutable=["counters"]
step.value = step.value + 1

nnx is closer to PyTorch’s ergonomics — no collection name, no mutable=....

Counter and JIT

Inside nnx.jit the counter would be threaded through split/merge automatically. We don’t jit here so the mutation is plain Python — call the model n times in a loop, observe step go up by 1 each time.

An un-jitted Python loop like this one would be slow at scale, since each call is dispatched eagerly with no XLA fusion; under nnx.jit, the Variable update is lifted across the jit boundary and runs as compiled code. Lifting is a Sprint C3 topic.

Common pitfalls

  • Saving as nnx.Param instead of nnx.Variable. Then the optimizer tries to “train” the counter. Bad.
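  • self.step += 1 (no .value). This operates on the Variable wrapper instead of its contents. Depending on the Flax version it may raise an error or rebind self.step to a raw array, after which self.step.value fails. Always update through .value.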
  • int(self.step.value) inside the loop. Nothing races in eager mode, but every conversion forces a device-to-host sync, and a read placed before the call sees the pre-increment value (0 on the first iteration). Keep it as a JAX scalar and read it once at the end.
  • n_calls = 0. Should leave the counter at 0; the loop runs 0 times.

Problem

Write counter_after_n_calls(seed, n_calls):

  1. Define StepCounter(nnx.Module) with __init__(self, rngs) consuming one key (for shape consistency) and self.step = nnx.Variable(jnp.array(0.0)). __call__(self) increments self.step.value by 1.0 and returns the new value.
  2. Build nnx.Rngs(int(seed)), instantiate StepCounter. Call it int(n_calls) times in a Python loop.
  3. Return jnp.array([float(model.step.value), float(int(n_calls))]).

Both should equal n_calls.

Inputs:

  • seed: int (passed as float).
  • n_calls: int (passed as float; can be 0).

Output: length-2 array [n_calls, n_calls].
