NNX Call vs Init
Why this matters
The biggest API difference between Linen and nnx is the death of
init/apply. In Linen the model is a blueprint that produces a params
pytree via model.init(key, x), and you run it via model.apply(params, x).
That means two entry points, and the params pytree has to be passed into
every forward pass.
In nnx, the parameters live ON the module instance, so the module IS the
forward function. You build it once and then call it as many times as you
want. No init, no apply, no params dict to thread.
This problem demonstrates deterministic, side-effect-free calling: building
a model once and calling it twice on the same input gives the same answer
both times, because the forward pass has no hidden side effects (yet: per-call
state arrives later with nnx.Variable).
API: build once, call many times
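```python
# Assumes: from flax import nnx, MyLinear as defined in the worked example,
# and input arrays x, x_other whose last dimension is 4.
rngs = nnx.Rngs(0)
model = MyLinear(in_features=4, out_features=8, rngs=rngs)
y1 = model(x)        # forward pass 1
y2 = model(x)        # forward pass 2: exactly equal to y1
y3 = model(x_other)  # forward pass 3: same params, different input
```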
The model object is reusable. Each call reads the same parameter values (no mutation) and computes from the input. Compare with Linen:
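```python
# Linen, for contrast (assumes a linen.Module version of MyLinear with a `features` field)
model = MyLinear(features=8)
params = model.init(jax.random.PRNGKey(0), x)  # returns a separate params pytree
y1 = model.apply(params, x)
y2 = model.apply(params, x)  # need to thread params through every call
```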
Linen forces you to keep params in scope and pass them to every call. nnx
folds them into the model object.
Why does this matter for training?
During training, you DO update the parameters — but you do it via the
optimizer (Phase 6), not by mutating the model directly. The forward pass
itself is still pure; calling the model twice on the same x between
optimizer steps gives identical outputs.
The exception is modules that hold other kinds of nnx.Variable state
(running averages, counters, KV caches). Those are written during the call
itself, so each call can leave the module in a different state. We'll see
them in problems 5 and 17.
Worked example
```python
import jax
import jax.numpy as jnp
from flax import nnx

class MyLinear(nnx.Module):
    def __init__(self, in_features, out_features, rngs: nnx.Rngs):
        key = rngs.params()  # draw a key from the "params" stream
        # Normal init scaled by 1/sqrt(in_features)
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))
        )
        self.bias = nnx.Param(jnp.zeros((out_features,)))

    def __call__(self, x):
        # Params are only read here, never written
        return x @ self.kernel + self.bias

model = MyLinear(in_features=4, out_features=2, rngs=nnx.Rngs(0))
x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(model(x))  # [a, b]
print(model(x))  # [a, b], identical
```
Common pitfalls
- Re-building the model in a loop. If you write for _ in range(N): model = MyLinear(...), you re-initialize the parameters on every iteration. With a fresh nnx.Rngs(seed) each time, you just rebuild the same weights; with one shared Rngs, you silently get a different model each time. Build once, call N times.
- Expecting the call to mutate. It doesn't, unless you put an nnx.Variable inside and mutate it explicitly. Pure params + pure forward = no side effects.
- Misreading “stateful”. nnx is described as “stateful” because the model owns its parameters as object state. That state persists on the object, but it's only mutated by you (or the optimizer), not by calling the model.
- Forgetting the bias. The nnx.Param(jnp.zeros((out_features,))) bias starts at zero, which matters for the test outputs.
Problem
Write call_vs_init_forward(seed, x, features):
- Define MyLinear(nnx.Module) with kernel + bias (both nnx.Param). Forward: x @ kernel + bias.
- Build the model with nnx.Rngs(int(seed)), in_features=x.shape[-1], out_features=int(features).
- Call the model TWICE on the same x. Capture both outputs as out_first, out_second.
- Return a length-2 array: [out_first.sum(), out_second.sum()].

The two values must be identical; that is the whole point of the exercise.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: length-2 1-D array [s, s].