
nn.jit (Flax-aware JIT lift)

Why this matters

JIT compilation is how JAX gets fast: instead of executing each op eagerly, jax.jit traces your function, builds an XLA program, compiles it once per input shape and dtype, and reuses the compiled binary on later calls with matching inputs. For training loops this is almost always a win; wrapping your train_step in jax.jit is table stakes.
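The trace-once behavior is easy to observe with a Python side effect, which runs only while JAX traces the function. This is a minimal sketch; scaled_sum and the global trace_count are hypothetical names used purely as a debugging trick, not something you would keep in real code.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces scaled_sum

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1              # Python side effect: runs at trace time, not at run time
    return jnp.sum(2.0 * x)

a = scaled_sum(jnp.ones(3))       # first call: trace + compile + run
b = scaled_sum(jnp.ones(3))       # same shape/dtype: reuses the compiled program
c = scaled_sum(jnp.ones(5))       # new shape: triggers a fresh trace + compile
print(trace_count)                # 2
```

The counter ends at 2, not 3: the second call hit the compilation cache, while the shape change forced a retrace.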

For ordinary Python functions you’d just use jax.jit. For Flax Modules, you have a choice:

  • jax.jit(model.apply): JIT the apply function from outside. Works fine for the common case.
  • nn.jit(MyModule): returns a NEW Module class whose methods are wrapped with jax.jit. This is the lifted version.

nn.jit matters when:

  • You want the JIT boundary INSIDE a Module hierarchy. E.g., a parent Module wraps a child via nn.jit(Child) so only the child’s body gets jitted, not the entire parent.
  • You’re composing multiple Flax lifts and want them all to be in the Module-class form so they nest cleanly (nn.scan(nn.jit(...)), etc.).
  • You want Flax’s variable collections (params, batch_stats, RNGs) handled correctly across the JIT boundary without manually threading them.

For straightforward training, jax.jit(model.apply) is simpler. nn.jit is the lifted version for Module-level compositions.

The canonical incantation

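import jax
import jax.numpy as jnp
import flax.linen as nn
F, x = 4, jnp.ones((3,))                         # example output width and input
JitDense = nn.jit(nn.Dense)                      # lift the CLASS, not an instance
model = JitDense(features=F)
params = model.init(jax.random.PRNGKey(0), x)
out = model.apply(params, x)                     # shape (F,)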

Init and apply behave like the un-jitted version (same param shapes, same outputs up to floating-point rounding); the difference is that the Module’s methods run as compiled XLA programs when init or apply is called.

Forward-pass equivalence

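Like nn.checkpoint (pos 82), nn.jit leaves the math unchanged: given the same params, it returns the same output as the un-wrapped Module, up to floating-point rounding (XLA may fuse or reorder ops). JIT changes how the computation runs, not what it computes.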

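Because of compile-time costs, the FIRST call (and the first call with any new input shape or dtype) will be slower than subsequent calls, since that is when the trace + compile happens. For benchmarking, always do a warm-up call before measuring, and call .block_until_ready() on the result, because JAX dispatches work asynchronously.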

When does jax.jit(model.apply) suffice?

Most of the time. Specifically:

  • If your top-level training step is one big jax.jit, it already covers all the Modules inside. Adding nn.jit on individual children is redundant.
  • If you only need to JIT the forward pass, wrap it once at startup (jit_apply = jax.jit(model.apply)) and reuse jit_apply for every call.

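Reach for nn.jit when you need finer-grained, Module-level JIT boundaries, e.g., a very large model where you don’t want one giant XLA program. This split only happens when the caller is not itself under jax.jit: an nn.jit nested inside an outer jax.jit is traced into the outer program, as with the train step above.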

Common pitfalls

  • Wrapping the instance: nn.jit(nn.Dense(features=F)) doesn’t work; pass the CLASS, then instantiate.
  • Expecting different results: the numerics match the un-jitted Module, up to floating-point rounding.
  • Re-creating JitDense each call: each top-level nn.jit(nn.Dense) triggers a fresh compile. Define it once at module-load time.

Problem

Implement nn_jit_forward(seed, x, features):

  1. Build JitDense = nn.jit(nn.Dense).
  2. Instantiate model = JitDense(features=features).
  3. Init with jax.random.PRNGKey(seed) and x.
  4. Apply.
  5. Return the output flattened.

Inputs:

  • seed: int.
  • x: 1-D (D,).
  • features: int F.

Output: 1-D (F,).

Hints

flax nn-jit lifted-transforms compilation
