
NNX Eager Debug

Why this matters

This is the showcase problem for nnx's workflow advantage over Linen. In Linen, every forward call is a pure functional round trip, output = model.apply(variables, x): you can't keep mutable Python state across calls without threading it through every call site.

In nnx, Module instances are plain Python objects with mutable attributes, and when you DON'T jit the model, every line of __call__ runs eagerly on concrete arrays. You can:

  • Mutate counters between calls.
  • Set breakpoints inside __call__ and inspect self.attr.
  • Print intermediate values with plain print() (no jax.debug.print needed).
  • Modify hyperparameters mid-iteration without rebuilding.

Eager mode is the debugging gear: get the model working iteratively in eager mode, then wrap it in nnx.jit for production speed.

API: nnx.Variable

nnx.Param is the canonical “trainable parameter” wrapper, but it is a subclass of the more general nnx.Variable, which holds any kind of mutable state: counters, EMA accumulators, KV caches.

# in __init__:
self.counter = nnx.Variable(jnp.array(0, dtype=jnp.int32))

# in the Module body:
def __call__(self, x):
    self.counter.value = self.counter.value + 1   # read, add 1, write back
    return self.l1(x)

On every call, the counter increments. Between calls, the value persists in self.counter.value because self is the same Python object.

The choice of nnx.Variable vs nnx.Param:

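  • nnx.Param: trained by default (gradients flow to it; nnx.grad differentiates nnx.Param state unless told otherwise, and filters like wrt=nnx.Param select it).
  • nnx.Variable: bare mutable storage; it gets no gradients unless you opt in with a filter that includes it.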

For a step counter, nnx.Variable is the right choice — you don’t want gradients flowing into it.

Why this fails in Linen

In Linen, you’d write:

import jax.numpy as jnp
import flax.linen as nn

class CountingModule(nn.Module):
    @nn.compact
    def __call__(self, x):
        counter = self.variable("state", "counter", lambda: jnp.array(0))
        counter.value = counter.value + 1                  # OOPS: apply() raises unless "state" is mutable
        return nn.Dense(4)(x)

But model.apply(variables, x) treats every collection as immutable unless you opt in: without mutable=["state"], the assignment above raises an error. With the opt-in, apply returns (output, mutated_collections), and the next call needs the updated state passed back IN alongside the params. You can do it, but it's a two-line ceremony at every call site:

out, new_state = model.apply({"params": p, "state": s}, x, mutable=["state"])
s = new_state["state"]

Forget the s = ... line and your “counter” silently doesn’t increment.

In nnx eager mode, the same update is self.counter.value = self.counter.value + 1, and it Just Works. That's the real workflow advantage: eager nnx code reads like PyTorch.

Worked example

import jax.numpy as jnp
from flax import nnx

class CountingModel(nnx.Module):
    def __init__(self, in_features, features, rngs):
        self.l1 = nnx.Linear(in_features, features, rngs=rngs)
        self.counter = nnx.Variable(jnp.array(0, dtype=jnp.int32))  # mutable, non-trainable

    def __call__(self, x):
        self.counter.value = self.counter.value + 1
        return self.l1(x)

model = CountingModel(3, 4, rngs=nnx.Rngs(0))
x = jnp.ones((3,))                   # example input with 3 features
for _ in range(5):
    out = model(x)                   # eager call: no jit
print(int(model.counter.value))      # 5

Five calls; counter is 5; out is the result of the fifth call. No state threading.

What about jit?

Wrapping the call in nnx.jit traces it once and caches the compiled function. The counter still increments on every call, not just the first, because nnx.jit handles in-place mutation of nnx.Variable: it splits the model into a static graphdef and a state pytree on entry, runs the compiled function on the state, and writes the updated state back into the original Module on exit, so the mutation is observable from the outside.

But for debugging, jit is the wrong tool: inside nnx.jit, print() and breakpoints see abstract tracers rather than concrete values. Skip jit, set breakpoints, and inspect self.counter.value between calls. Once things work, wrap the train step in nnx.jit.

Common pitfalls

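  • Plain Python int as counter. Without nnx.Variable, the counter is stored in the static graphdef, not in the state. Incrementing it works in plain eager Python, but it never appears in nnx.split/nnx.state output (so it isn't saved with the state), and under nnx.jit it is a compile-time constant rather than a traced value.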
  • self.counter += 1. That binds a new attribute named counter to a fresh JAX array, replacing the nnx.Variable wrapper. Use self.counter.value = self.counter.value + 1 to preserve the wrapper.
  • Treating the counter as a Python value at trace time. Inside nnx.jit, self.counter.value is a Tracer with a known shape and dtype but no concrete value: fine for array math (e.g. jnp.where), broken if you call int() on it or branch on it with a Python if.
  • Forgetting that the state survives across calls. Some users expect a fresh Module per call (the Linen pattern). In nnx you build once and reuse, so the state accumulates; reset it explicitly (e.g. model.counter.value = jnp.array(0, dtype=jnp.int32)) or rebuild the model when you need a fresh count.

Problem

Write eager_debug_step_count(seed, x, features, n_calls):

  1. Define CountingModel(nnx.Module) with one nnx.Linear (in -> features) and a self.counter = nnx.Variable(jnp.array(0, dtype=jnp.int32)).
  2. __call__(x) increments self.counter.value by 1, then returns self.l1(x).
  3. Build the model. Call model(x) n_calls times in a Python loop (no jit). Capture the output of the last call.
  4. Return jnp.array([float(final_count), float(last_out.sum())]).

The first entry is n_calls. The second is the sum of the last forward's output (identical to the first forward's output, since l1 is deterministic and the input doesn't change; only the counter has moved).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).
  • n_calls: int (passed as float).

Output: length-2 [final_count, output_sum].

Hints

flax nnx debug eager variable mutability
