NNX Eager Debug
Why this matters
This is the showcase problem for the workflow advantage that
nnx has over Linen. In Linen, every forward call is a pure
functional round-trip through output, new_state = model.apply(variables, x, mutable=[...]).
You can't keep mutable Python state across calls without threading
it through every call site.
In nnx, when you DON’T jit the model, Module instances are plain Python objects with mutable attributes. You can:
- Mutate counters between calls.
- Set breakpoints inside __call__ and inspect self.attr.
- Print intermediate values with print() (no jax.debug.print needed).
- Modify hyperparameters mid-iteration without rebuilding.
Eager mode is the debugging gear. Get a model working iteratively
in eager mode, then wrap it in nnx.jit for production speed.
API: nnx.Variable
nnx.Param is the canonical “trainable parameter” wrapper, but
nnx has a more general nnx.Variable that holds any kind of
mutable state — counters, EMA accumulators, KV caches.
```python
# In __init__:
self.counter = nnx.Variable(jnp.array(0, dtype=jnp.int32))

# In the module body:
def __call__(self, x):
    self.counter.value = self.counter.value + 1
    return self.l1(x)
```
On every call, the counter increments. Between calls, the value
persists in self.counter.value because self is the same Python
object.
The choice of nnx.Variable vs nnx.Param:
- nnx.Param: trained by default (gradients flow to it; wrt=nnx.Param picks it up).
- nnx.Variable: bare mutable storage; it doesn't get gradients unless you opt in.
For a step counter, nnx.Variable is the right choice — you don’t
want gradients flowing into it.
Why this fails in Linen
In Linen, you’d write:
```python
import flax.linen as nn
import jax.numpy as jnp

class CountingModule(nn.Module):
    @nn.compact
    def __call__(self, x):
        counter = self.variable("state", "counter", lambda: jnp.array(0))
        counter.value = counter.value + 1  # OOPS: raises unless "state" is mutable
        return nn.Dense(4)(x)
```
Without mutable=["state"], the OOPS line raises an error because
the "state" collection is immutable during apply. With it,
model.apply returns (output, mutated_collections), and the next
call needs the updated state passed back in as part of the
variables dict. You can do it, but it's a two-line ceremony on
every call:
```python
out, new_state = model.apply({"params": p, "state": s}, x, mutable=["state"])
s = new_state["state"]
```
Forget the s = ... line and your “counter” silently doesn’t
increment.
In nnx eager mode, the same code is self.counter.value = self.counter.value + 1, and it Just Works. That's the real workflow
advantage: eager nnx code reads like PyTorch.
Worked example
```python
import jax.numpy as jnp
from flax import nnx

class CountingModel(nnx.Module):
    def __init__(self, in_features, features, rngs):
        self.l1 = nnx.Linear(in_features, features, rngs=rngs)
        self.counter = nnx.Variable(jnp.array(0, dtype=jnp.int32))

    def __call__(self, x):
        self.counter.value = self.counter.value + 1
        return self.l1(x)

model = CountingModel(3, 4, rngs=nnx.Rngs(0))
x = jnp.ones((3,))  # example input with in_features = 3
for _ in range(5):
    out = model(x)
print(int(model.counter.value))  # 5
```
Five calls; counter is 5; out is the result of the fifth call. No state threading.
What about jit?
Wrapping the call in nnx.jit traces the function on the first call
and reuses the compiled trace afterwards. The counter still
increments on every call because nnx.jit handles in-place mutation
of nnx.Variable. It splits the model into a graphdef and state on
entry and merges the updated state back on exit, so the mutation
is observable from the outside.
But for debugging, jit is the wrong tool. Skip jit, set
breakpoints, and inspect self.counter.value between calls. Once
things work, wrap the train step in nnx.jit.
Common pitfalls
- Plain Python int as counter. Without nnx.Variable, the counter goes into the static graphdef, not the state. Mutating it works in plain eager code, but it doesn't survive split/merge (for example, across an nnx.jit boundary).
- self.counter += 1. This can rebind the attribute counter to a fresh JAX array, replacing the nnx.Variable wrapper. Use self.counter.value = self.counter.value + 1 to preserve the wrapper.
- Using the counter as a concrete Python value at trace time. Inside nnx.jit, the counter is a Tracer with a known shape but an unknown value. That is fine for array arithmetic, but it breaks if you try to use it as a Python int (int(...), an if test, range(...)).
- Forgetting that the state survives across calls. Some users expect a fresh Module() instantiation per call (the Linen pattern). In nnx, you build once and reuse, so the state accumulates.
Problem
Write eager_debug_step_count(seed, x, features, n_calls):
- Define CountingModel(nnx.Module) with one nnx.Linear (in -> features) and a self.counter = nnx.Variable(jnp.array(0, dtype=jnp.int32)).
- __call__(x) increments self.counter.value by 1, then returns self.l1(x).
- Build the model. Call model(x) n_calls times in a Python loop (no jit). Capture the output of the last call.
- Return jnp.array([float(final_count), float(last_out.sum())]).
The first entry is n_calls. The second is the sum of the last
forward's output. It matches the first forward's output because l1 is
deterministic and the input doesn't change; only the counter has
moved.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
- n_calls: int (passed as float).
Output: length-2 [final_count, output_sum].