
NNX Debug Print

Why this matters

The classic pain of jitted JAX is that print(x) inside a jitted function is not useful: at trace time, x is a Tracer, not an array, so the output shows an abstract shape and dtype rather than values. By the time the values are real, you’re outside the jit boundary and the intermediates are no longer available to inspect.

jax.debug.print is the answer: a jit-safe print that queues the print to happen at runtime, when actual values exist. It works inside jit, vmap, scan, grad — anywhere.

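import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Queued during tracing; runs on the host with the concrete value.
    jax.debug.print("intermediate: {v}", v=x * 2)
    return x ** 2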

Calling f(jnp.array(3.0)) prints intermediate: 6.0, the materialized value, at runtime while the compiled function executes.
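
To make the "inside jit, vmap, scan, grad" claim concrete, here is a minimal sketch of the same print under three transformations. The function names square_with_print and scan_step are illustrative only, not part of any library.

```python
import jax
import jax.numpy as jnp

def square_with_print(x):
    # Queued during tracing, executed at runtime with the concrete value.
    jax.debug.print("x = {x}", x=x)
    return x ** 2

# Under vmap the print fires once per batch element.
batched = jax.vmap(square_with_print)(jnp.arange(3.0))

# Under grad the print reports the input value seen in the forward pass.
grad_at_3 = jax.grad(square_with_print)(3.0)

def scan_step(carry, x):
    new_carry = carry + x
    # Fires once per scan iteration.
    jax.debug.print("running sum = {c}", c=new_carry)
    return new_carry, new_carry

total, partial_sums = jax.lax.scan(
    scan_step, jnp.array(0.0), jnp.arange(1.0, 4.0)
)
```

The returned values are unaffected by the prints: batched holds the squares, grad_at_3 the derivative 2x at x = 3, and partial_sums the running totals.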

API

jax.debug.print(fmt, *args, ordered=False, **kwargs)

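The format string follows Python’s str.format conventions: {name} placeholders are filled from keyword arguments and {} placeholders from positional arguments. Both kinds of argument are interpreted as JAX arrays and queued for runtime materialization. A few useful patterns: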

jax.debug.print("got x={x}, y={y}", x=x, y=y)
jax.debug.print("first three: {v}", v=arr[:3])
jax.debug.print("norm = {n}", n=jnp.linalg.norm(x))

Order: by default, prints inside jit are unordered — they happen in some order that depends on XLA scheduling, not necessarily source order. Pass ordered=True for source-order:

jax.debug.print("step {s}: loss={l}", s=step, l=loss, ordered=True)

Ordered prints are slightly slower (they force a barrier) but they make logs readable in training loops.
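
As a small sketch of ordered prints, the function below (two_stage is a made-up name) emits two messages that should appear in source order. The jax.effects_barrier() call at the end is included on the assumption that you want all pending debug output flushed before continuing.

```python
import jax
import jax.numpy as jnp

@jax.jit
def two_stage(x):
    y = x + 1.0
    # ordered=True asks JAX to preserve source order among ordered prints.
    jax.debug.print("stage 1: {y}", y=y, ordered=True)
    z = y * 2.0
    jax.debug.print("stage 2: {z}", z=z, ordered=True)
    return z

out = two_stage(jnp.array(1.0))

# Block until outstanding side effects (including debug prints) have run.
jax.effects_barrier()
```

Dropping ordered=True from both calls leaves the returned value unchanged; only the relative order of the two log lines is no longer guaranteed.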

Why this matters in nnx

nnx Modules can use jax.debug.print directly in __call__. The print becomes part of the traced computation; it survives nnx.jit, nnx.vmap, etc. The Module’s forward function still returns a real array — the print is a pure side effect alongside the return value.

Putting prints inside Module __call__ is also natural for “debug instrumentation”: flip a flag, inject prints in every block, observe activations, flip the flag off. No separate apply_fn plumbing.

Worked example

import jax
import jax.numpy as jnp
from flax import nnx

class DebugModel(nnx.Module):
    def __init__(self, in_features, features, rngs):
        self.l1 = nnx.Linear(in_features, features, rngs=rngs)
        self.l2 = nnx.Linear(features, features, rngs=rngs)

    def __call__(self, x):
        h = self.l1(x)
        # Runtime print of the post-l1 activation norm; does not alter the output.
        jax.debug.print("layer 1 norm: {n}", n=jnp.linalg.norm(h))
        return self.l2(h)

Calling an instance, for example DebugModel(in_features, features, rngs=nnx.Rngs(0))(x), prints the L2 norm of the post-l1 activation at runtime, then returns the final l2 output. In a training loop you’d see one print per step, which is useful for spotting when activations explode (norm shoots up) or vanish (norm goes to 0).

Compare with plain print

print(jnp.linalg.norm(h)) inside a non-jitted function works fine — h is a concrete array. But the moment you wrap with jit or vmap, plain print shows a Tracer object, not a value. jax.debug.print survives the trace.

Inside an unjitted nnx Module (eager mode), plain print also works. But nnx code is often jitted later, so using jax.debug.print from the start is a safer habit: it keeps working whether or not the model is later wrapped with nnx.jit.
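
The difference is easy to see side by side. In this sketch (norm_with_both_prints is an illustrative name), the plain print runs only while JAX traces the function, whereas jax.debug.print runs on every call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def norm_with_both_prints(h):
    n = jnp.linalg.norm(h)
    # Executes at trace time only, and shows a Tracer rather than a number.
    print("plain print sees:", n)
    # Executes at runtime on every call, with the concrete value.
    jax.debug.print("jax.debug.print sees: {n}", n=n)
    return n

h = jnp.array([3.0, 4.0])
first = norm_with_both_prints(h)   # traces: both messages appear
second = norm_with_both_prints(h)  # cached: only the jax.debug.print message
```

On the second call the compiled function is reused, so the plain print line does not appear at all.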

Common pitfalls

  • Calling print instead of jax.debug.print inside jit. Plain print happens at trace time and shows a Tracer; you lose the value information.
  • Forgetting that prints are unordered by default. If two prints in the same forward pass come out in surprising order, pass ordered=True.
  • Treating the print as a return value. jax.debug.print returns None; the function still has to return something.
  • Callback overhead in tight loops. Each jax.debug.print call queues a host-side callback every time it executes. In a 100k-step training loop, that’s noticeable, so enable prints only while debugging, not in production.
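
If prints must stay in a long loop, one option is to fire them only occasionally. The sketch below wraps the print in jax.lax.cond so that it is expected to run only on every LOG_EVERY-th step; LOG_EVERY and log_loss_sometimes are hypothetical names, and the approach assumes the function is not vmapped (under vmap a cond may be lowered so that both branches execute).

```python
import jax
import jax.numpy as jnp
from jax import lax

LOG_EVERY = 100  # hypothetical logging interval

@jax.jit
def log_loss_sometimes(step, loss):
    def do_print(step, loss):
        jax.debug.print("step {s}: loss={l}", s=step, l=loss)

    def skip(step, loss):
        pass  # no output; both branches return None

    # Select the branch at runtime based on the step counter.
    lax.cond(step % LOG_EVERY == 0, do_print, skip, step, loss)
    return loss

for step in range(0, 300, 50):
    loss = log_loss_sometimes(jnp.array(step), jnp.array(1.0 / (step + 1)))
```

With this loop, messages would be expected for steps 0, 100, and 200 only, while the returned loss is passed through unchanged on every step.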

Problem

Write debug_print_passthrough(seed, x, features):

  1. Define a DebugModel(nnx.Module) with two nnx.Linear layers. First maps x.shape[-1] -> features, second maps features -> features.
  2. In __call__(x): compute h = self.l1(x). Call jax.debug.print("layer 1 norm: {n}", n=jnp.linalg.norm(h)). Return self.l2(h).
  3. The function builds the model and returns model(x). The print happens; the return value is the real second-layer output.

Inputs:

  • seed: int (passed as float; cast with int() before use).
  • x: 1-D JAX array.
  • features: int (passed as float; cast with int() before use).

Output: 1-D array of shape (features,), the second layer’s output.

Hints

flax nnx debug jax-debug-print
