NNX nnx.jit
Why this matters
Calling a Flax NNX module eagerly works — model(x) runs every op as
Python sees it, with no XLA fusion, no caching, no tracing. For a
one-off forward, that’s fine. For training, where the same forward
runs millions of times, it’s a disaster: you pay Python dispatch
overhead per op, every step.
jax.jit is the standard fix, but it operates on pure functions
over pytrees. An nnx Module is neither — it’s a stateful Python
object that owns nnx.Variables. To use raw jax.jit, you’d have
to manually nnx.split(model), jit a function that takes
(graphdef, state, x), then nnx.merge inside. Possible, but
boilerplate-heavy and easy to get wrong.
nnx.jit is the lifted version of jax.jit that understands nnx
state. You hand it a function that takes a Module as its first
argument; under the hood it does the split/merge for you, hashes
the GraphDef as static, treats the State pytree as the dynamic
arg, and caches by GraphDef + input shapes/dtypes. Subsequent
calls hit the cache — no retracing.
The canonical incantation
from flax import nnx

# in_features, out_features, seed and x are placeholders supplied by the caller.
model = nnx.Linear(in_features, out_features, rngs=nnx.Rngs(seed))

@nnx.jit
def forward(model, x):
    return model(x)

out = forward(model, x)  # first call: trace + compile
out = forward(model, x)  # subsequent: cache hit
Three things to notice:
- The function takes model like any other Python arg. nnx.jit knows what to do with it.
- You can call it repeatedly with the same model; caching kicks in as long as input shapes/dtypes are stable.
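- The return value is a plain JAX array, the same one a non-jitted forward would produce. Numerics should match eager up to floating-point tolerance; XLA may fuse or reorder operations, so bit-for-bit equality is not guaranteed.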
Cache key, in detail
nnx.jit caches on (GraphDef of model) + (shape/dtype of every other input). So:
- Calling with a model whose structure changed (e.g. you swapped a layer) → cache miss, retrace.
- Calling with x of a different shape → cache miss, retrace.
- Calling with the same model after optimizer.update(...) mutated its params → cache hit. Param values are dynamic; only the GraphDef is static. This is how you can jit(train_step) and have it stay fast across thousands of optimizer steps.
Common pitfalls
- Wrapping jax.jit directly around a model call. jax.jit(model) tries to trace model itself as a pytree, which fails (or worse, silently leaks state). Use nnx.jit instead.
- Recompiling on every call. If your input shapes vary (variable sequence length, dynamic batch size), every new shape triggers a fresh compile. Pad to fixed shapes if you can.
- Mutating the model from inside the jitted function in unusual ways. Standard ops (parameter updates via optimizer.update, BatchNorm stats, RNG counters) are tracked. Custom mutations on plain Python attrs may not be; keep mutation explicit and through nnx primitives.
Problem
Implement jit_forward(seed, x, features, num_calls):
- Build model = nnx.Linear(x.shape[-1], int(features), rngs=nnx.Rngs(int(seed))).
- Define @nnx.jit def fwd(model, x): return model(x).
- Call fwd(model, x) int(num_calls) times in a Python loop. (First call traces+compiles; the rest hit the cache.)
- Return the LAST output flattened: out.reshape(-1).
The point is not to time anything — the test verifies that the jit’d output equals the eager output (and that calling repeatedly does not corrupt the model).
Inputs:
- seed: float (cast to int).
- x: 2-D (N, D_in).
- features: float (cast to int), D_out.
- num_calls: float (cast to int), how many times to call.
Output: 1-D (N * features,) — flattened final output.