
NNX Call vs Init

Why this matters

The biggest API difference between Linen and nnx is the death of init/apply. In Linen the model is a blueprint: model.init(key, x) produces a params pytree, and model.apply(params, x) runs the forward pass. You call init once, then pass params into apply on every forward.

In nnx, the parameters live ON the module instance — so the module IS the forward function. You build it once and then call it as many times as you want. No init, no apply, no params dict to thread.

This problem demonstrates idempotent calling: building a model once and calling it twice on the same input gives the same answer both times, because the model has no hidden side effects (yet — that comes with nnx.Variable later).

API: build once, call many times

rngs = nnx.Rngs(0)
model = MyLinear(in_features=4, out_features=8, rngs=rngs)

y1 = model(x)           # forward pass 1
y2 = model(x)           # forward pass 2 — exactly equal to y1
y3 = model(x_other)     # forward pass 3 — same params, different input

The model object is reusable. Each call reads the same parameter values (no mutation) and computes from the input. Compare with Linen:

# Linen — for contrast
model = MyLinear(features=8)
params = model.init(jax.random.PRNGKey(0), x)
y1 = model.apply(params, x)
y2 = model.apply(params, x)    # need to thread params through every call

Linen forces you to keep params in scope. nnx folds it into the model object.

Why does this matter for training?

During training, you DO update the parameters — but you do it via the optimizer (Phase 6), not by mutating the model directly. The forward pass itself is still pure; calling the model twice on the same x between optimizer steps gives identical outputs.

The exception is modules with nnx.Variable state (running averages, counters, KV caches). Those mutate per-call. We’ll see them in problems 5 and 17.

Worked example

import jax
import jax.numpy as jnp
from flax import nnx

class MyLinear(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        key = rngs.params()  # draw a PRNG key for parameter init
        # Kernel: standard normal, scaled by 1/sqrt(in_features)
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))
        )
        # Bias: starts at zero
        self.bias = nnx.Param(jnp.zeros((out_features,)))

    def __call__(self, x):
        return x @ self.kernel + self.bias

model = MyLinear(in_features=4, out_features=2, rngs=nnx.Rngs(0))
x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(model(x))      # [a, b], two seed-dependent values
print(model(x))      # [a, b], identical to the first call

Common pitfalls

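  • Re-building the model in a loop. If you write for _ in range(N): model = MyLinear(...), you’re re-initializing the parameters on every iteration: with the same values if each iteration also creates a fresh nnx.Rngs(seed), or with new random values if one Rngs object is shared across iterations. Build once, call N times.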
  • Expecting the call to mutate. It doesn’t, unless you put an nnx.Variable inside and mutate it explicitly. Pure params + pure forward = no side effects.
  • Misreading “stateful”. nnx is described as “stateful” because the model owns its parameters as object state. That’s static state — it’s only mutated by you (or the optimizer), not by calling the model.
  • Forgetting the bias. The bias is nnx.Param(jnp.zeros((out_features,))), so it starts at zero. Initializing it any other way changes the test outputs.

Problem

Write call_vs_init_forward(seed, x, features):

  1. Define MyLinear(nnx.Module) with kernel + bias (both nnx.Param). Forward: x @ kernel + bias.
  2. Build the model with nnx.Rngs(int(seed)), in_features=x.shape[-1], out_features=int(features).
  3. Call the model TWICE on the same x. Capture both outputs as out_first, out_second.
  4. Return a length-2 array: [out_first.sum(), out_second.sum()].

The two values must be identical — that is the whole point of the exercise.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).

Output: length-2 1-D array [s, s].

Hints

flax nnx idempotent
