nn.jit (Flax-aware JIT lift)
Why this matters
JIT compilation is how JAX gets fast: instead of executing each
op eagerly, the JIT traces your function, builds an XLA program,
compiles it once, and runs the compiled binary on every call.
For training loops this is almost always a win, because the one-time compile cost is amortized over many steps: wrapping your train_step in jax.jit is table stakes.
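To make "trace once, run the compiled program on every call" concrete, here is a minimal sketch with plain jax.jit. The function scaled_sum and the trace_count counter are hypothetical, chosen only for illustration. A Python side effect inside a jitted function runs only while JAX traces it, so the counter counts traces (and hence compiles), not calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces scaled_sum


@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jnp.sum(2.0 * x)


first = scaled_sum(jnp.arange(4.0))   # first call: trace + compile
second = scaled_sum(jnp.arange(4.0))  # same shape/dtype: reuses the compiled program
third = scaled_sum(jnp.arange(5.0))   # new input shape: JAX traces and compiles again
print(trace_count, float(first))      # prints: 2 12.0
```

Retracing is keyed on input shapes and dtypes, which is why the third call compiles a second program.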
For ordinary Python functions you’d just use jax.jit. For Flax
Modules, you have a choice:
- jax.jit(model.apply): JIT the apply function from outside. Works fine for the common case.
- nn.jit(MyModule): returns a NEW Module class whose body is wrapped with jax.jit when the Module is called. This is the lifted version.
nn.jit matters when:
- You want the JIT boundary INSIDE a Module hierarchy. E.g., a parent Module wraps a child via nn.jit(Child) so only the child’s body gets jitted, not the entire parent.
- You’re composing multiple Flax lifts and want them all to be in the Module-class form so they nest cleanly (nn.scan(nn.jit(...)), etc.).
- You want Flax’s variable collections (params, batch_stats) and RNG streams handled correctly across the JIT boundary without threading them through manually.
For straightforward training, jax.jit(model.apply) is simpler.
nn.jit is the lifted version for Module-level compositions.
The canonical incantation
JitDense = nn.jit(nn.Dense)
model = JitDense(features=F)
params = model.init(rng, x)
out = model.apply(params, x)
Init and apply behave like the un-jitted version (same output, same param shapes). The only difference is that the Module’s internal JAX ops are compiled into an XLA program when the Module is called, under both init and apply.
Forward-pass equivalence
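Like nn.checkpoint (pos 82), nn.jit produces the same numeric output as the un-wrapped Module. JIT changes timing, not semantics. The one caveat is floating point: XLA may fuse or reorder operations, so compare jitted and un-jitted outputs with a small tolerance rather than bitwise.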
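Because of compile-time costs, the first call after wrapping is slower than subsequent calls (the trace + compile happens once). For benchmarking, always do a warm-up call before measuring, and call .block_until_ready() on the result before stopping the timer, since JAX dispatches work asynchronously.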
When does jax.jit(model.apply) suffice?
Most of the time. Specifically:
- If your top-level training step is one big jax.jit, it already covers all the Modules inside. Adding nn.jit on individual children is redundant.
- If you only ever call model.apply (not model.init), you can just write jit_apply = jax.jit(model.apply) once at startup.
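Reach for nn.jit when you need finer-grained, Module-level JIT boundaries, e.g., a very large model where you don’t want one giant XLA program. This only helps when the caller is not itself wrapped in one outer jax.jit; inside an outer jit, the nested calls are traced into the same program.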
Common pitfalls
- Wrapping the instance: nn.jit(nn.Dense(features=F)) doesn’t work; pass the CLASS, then instantiate.
- Expecting different results: the numerics match the un-wrapped Module.
- Re-creating JitDense on each call: each top-level nn.jit(nn.Dense) builds a new lifted class and triggers a fresh compile. Define it once at module-load time.
Problem
Implement nn_jit_forward(seed, x, features):
- Build JitDense = nn.jit(nn.Dense).
- Instantiate model = JitDense(features=features).
- Init with PRNGKey(seed) and x.
- Apply.
- Return the output flattened.
Inputs:
- seed: int.
- x: 1-D (D,).
- features: int F.
Output: 1-D (F,).