NNX Debug Print
Why this matters
The classical pain of jit’d JAX is that a plain print(x) inside a
jitted function runs at trace time, when x is a Tracer rather than
an array, so what gets printed is an abstract shape and dtype, not
a value. By the time the values are real, you’re outside the jit
boundary and there’s nothing left to inspect.
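A minimal sketch of the problem, assuming a recent JAX install. The function name g is made up for illustration. The exact Tracer text varies between JAX versions, so no expected output is shown.

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    # This line runs while JAX traces g, so x is an abstract Tracer here.
    print("trace-time x:", x)
    return x ** 2

y = g(jnp.array(3.0))        # prints a Tracer description, not 3.0
y_again = g(jnp.array(4.0))  # same shape and dtype: the cached trace is reused, nothing is printed
print(y, y_again)            # outside jit the results are ordinary arrays
```

The second call is the more surprising part: because the Python body only runs during tracing, a plain print may not fire at all on later calls.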
jax.debug.print is the answer: a jit-safe print that queues
the print to happen at runtime, when actual values exist. It
works inside jit, vmap, scan, grad — anywhere.
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Queued at trace time, printed at runtime with the real value.
    jax.debug.print("intermediate: {v}", v=x * 2)
    return x ** 2
Calling f(jnp.array(3.0)) prints intermediate: 6.0, the
materialized value, at runtime, while the compiled function executes.
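The same call also works under other transformations. The sketch below, with a made-up function name square_with_print, runs it under vmap and grad. Under vmap the print is expected to fire once per batch element, and under grad it reports the forward-pass value.

```python
import jax
import jax.numpy as jnp

def square_with_print(x):
    # No jit here; vmap and grad trace the function just like jit does.
    jax.debug.print("x = {x}", x=x)
    return x ** 2

# vmap: one print per element of the batch.
batched = jax.vmap(square_with_print)(jnp.arange(3.0))

# grad: the print shows the primal input; the result is d(x^2)/dx = 2x.
slope = jax.grad(square_with_print)(3.0)
```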
API
jax.debug.print(fmt, *args, ordered=False, **kwargs)
The format string uses str.format-style placeholders: {name} for
keyword arguments and {} for positional ones. The arguments are
treated as JAX arrays and queued for runtime materialization; only
the format string itself is fixed at trace time. A few useful shapes:
jax.debug.print("got x={x}, y={y}", x=x, y=y)
jax.debug.print("first three: {v}", v=arr[:3])
jax.debug.print("norm = {n}", n=jnp.linalg.norm(x))
Order: by default, prints inside jit are unordered — they
happen in some order that depends on XLA scheduling, not necessarily
source order. Pass ordered=True for source-order:
jax.debug.print("step {s}: loss={l}", s=step, l=loss, ordered=True)
Ordered prints are slightly slower (they force a barrier) but they make logs readable in training loops.
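As an illustration of ordered prints in a loop, here is a small sketch. The function toy_step and its hand-written gradient step are hypothetical stand-ins for a real training step, not part of any library.

```python
import jax
import jax.numpy as jnp

@jax.jit
def toy_step(step, w):
    # Toy loss: sum of squares, whose gradient is 2 * w.
    loss = jnp.sum(w ** 2)
    jax.debug.print("step {s}: loss={l}", s=step, l=loss, ordered=True)
    w = w - 0.1 * (2 * w)  # plain gradient-descent update, learning rate 0.1
    jax.debug.print("step {s}: updated w={w}", s=step, w=w, ordered=True)
    return w

w = jnp.array([1.0, -2.0])
for i in range(3):
    w = toy_step(i, w)
```

With ordered=True on both calls, the loss line for each step should appear before the matching update line.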
Why this matters in nnx
nnx Modules can use jax.debug.print directly in __call__. The
print becomes part of the traced computation; it survives
nnx.jit, nnx.vmap, etc. The Module’s forward function still
returns a real array — the print is a pure side effect alongside
the return value.
Putting prints inside Module __call__ is also natural for
“debug instrumentation”: flip a flag, inject prints in every
block, observe activations, flip the flag off. No separate
apply_fn plumbing.
Worked example
import jax
import jax.numpy as jnp
from flax import nnx

class DebugModel(nnx.Module):
    def __init__(self, in_features, features, rngs):
        self.l1 = nnx.Linear(in_features, features, rngs=rngs)
        self.l2 = nnx.Linear(features, features, rngs=rngs)

    def __call__(self, x):
        h = self.l1(x)
        # Runtime print of the post-l1 activation norm.
        jax.debug.print("layer 1 norm: {n}", n=jnp.linalg.norm(h))
        return self.l2(h)
Calling DebugModel(in_features, features, rngs)(x) prints the L2
norm of the post-l1 activation at runtime, then returns the final
l2 output. In a training loop you’d see one print per step, which
is useful for spotting when activations explode (norm shoots up) or
vanish (norm goes to 0).
Compare with plain print
print(jnp.linalg.norm(h)) inside a non-jitted function works
fine — h is a concrete array. But the moment you wrap with
jit or vmap, plain print shows a Tracer object, not a
value. jax.debug.print survives the trace.
Inside an unjitted nnx Module (eager mode, like pos 97), plain
print also works. But code in nnx is often jitted later, so
using jax.debug.print from the start is a safer habit — it
keeps working regardless of whether the model is later wrapped
with nnx.jit.
Common pitfalls
- Calling print instead of jax.debug.print inside jit. Plain print happens at trace time and shows a Tracer; you lose the value information.
- Forgetting that prints are unordered by default. If two prints in the same forward pass come out in surprising order, pass ordered=True.
- Treating the print as a return value. jax.debug.print returns None; the function still has to return something.
- Heavy-format-string overhead in tight loops. Each jax.debug.print queues a host-side callback. In a 100k-step training loop, that’s noticeable; only enable for debugging, not in production.
Problem
Write debug_print_passthrough(seed, x, features):
- Define a DebugModel(nnx.Module) with two nnx.Linear layers. First maps x.shape[-1] -> features, second maps features -> features.
- In __call__(x): compute h = self.l1(x). Call jax.debug.print("layer 1 norm: {n}", n=jnp.linalg.norm(h)). Return self.l2(h).
- The function builds the model and returns model(x). The print happens; the return value is the real second-layer output.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: 1-D (features,) — the second layer’s output.