
NNX nnx.jit

Why this matters

Calling a Flax NNX module eagerly works: model(x) dispatches every op from Python one at a time, with no XLA fusion across ops and no whole-function compilation. For a one-off forward pass, that's fine. For training, where the same forward pass runs thousands or millions of times, it is a serious cost: you pay Python dispatch overhead for every op, on every step.

jax.jit is the standard fix, but it operates on pure functions over pytrees. An NNX Module is neither: it is a stateful Python object that owns nnx.Variables. To use raw jax.jit, you would have to nnx.split(model) into a GraphDef and a State by hand, jit a function that takes (graphdef, state, x) with graphdef held static, call nnx.merge inside, and, if the forward pass mutates anything, return the new state and write it back with nnx.update. Possible, but boilerplate-heavy and easy to get wrong.

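nnx.jit is the lifted version of jax.jit that understands NNX state. You hand it a function whose arguments include one or more Modules (conventionally the first argument). Under the hood it does the split/merge for you: the GraphDef is treated as a static, hashable argument, the State pytree as the dynamic argument, and compiled code is cached by GraphDef plus input shapes/dtypes. After the call, any state the function updated is written back into the original Module. Subsequent calls with the same structure and shapes hit the cache, with no retracing.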

The canonical incantation

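```python
import jax.numpy as jnp
from flax import nnx

in_features, out_features, seed = 4, 3, 0  # illustrative values
model = nnx.Linear(in_features, out_features, rngs=nnx.Rngs(seed))
x = jnp.ones((2, in_features))

@nnx.jit
def forward(model, x):
    return model(x)

out = forward(model, x)   # first call: trace + compile
out = forward(model, x)   # subsequent calls: cache hit
```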

Three things to notice:

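  • The function takes model like any other Python arg; nnx.jit recognizes it as a Module and splits it into a static GraphDef and a traced State for you.
  • You can call it repeatedly with the same model. Caching kicks in as long as the model structure and the input shapes/dtypes stay the same.
  • The return value is a plain JAX array, the same value a non-jitted forward would produce up to floating-point rounding. XLA may fuse or reorder operations, so compare against eager with a tolerance (jnp.allclose) rather than exact equality.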

Cache key, in detail

nnx.jit caches on the model's GraphDef plus the shape/dtype of every array it receives (the model's State leaves as well as the other inputs). So:

  • Calling with a model whose structure changed (e.g. you swapped a layer) → cache miss, retrace.
  • Calling with x of a different shape → cache miss, retrace.
  • Calling with the same model after optimizer.update(...) mutated its params → cache hit. Param values are dynamic; only the GraphDef is static. This is how you can nnx.jit a train_step and have it stay fast across thousands of optimizer steps.

Common pitfalls

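  • Wrapping jax.jit directly around a model call. jax.jit does not know how to split an NNX Module into static structure and dynamic state. jax.jit(model) treats the Module as an opaque callable, so its current parameter values are baked into the compiled function as constants: later updates can be silently ignored, and layers that mutate state during the call can fail. Use nnx.jit instead, or the explicit split/merge pattern.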
  • Recompiling on every call. If your input shapes vary (variable sequence length, dynamic batch size), every new shape triggers a fresh compile. Pad to fixed shapes if you can.
  • Mutating the model from inside the jitted function in unusual ways. Standard updates (parameter updates via optimizer.update, BatchNorm statistics, RNG counters) live in nnx.Variables and are tracked. Custom mutations of plain Python attributes may not be; keep mutation explicit by storing mutable state in an nnx.Variable and updating it through .value.

Problem

Implement jit_forward(seed, x, features, num_calls):

  1. Build model = nnx.Linear(x.shape[-1], int(features), rngs=nnx.Rngs(int(seed))).
  2. Define @nnx.jit def fwd(model, x): return model(x).
  3. Call fwd(model, x) int(num_calls) times in a Python loop. (First call traces+compiles; the rest hit the cache.)
  4. Return the LAST output flattened: out.reshape(-1).

The point is not to time anything: the test verifies that the jitted output matches the eager output, and that calling repeatedly does not corrupt the model.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D_in).
  • features: float (cast to int) — D_out.
  • num_calls: float (cast to int) — how many times to call.

Output: 1-D (N * features,) — flattened final output.

Hints

flax nnx lifted-transforms jit
