jax.debug.print
Why this matters
A common frustration when learning JAX: you add a print() inside a
jit-compiled function to inspect a value, and the print fires only once,
at trace time, showing an abstract tracer object instead of the actual
numbers you wanted. Later calls with the same input shapes and dtypes
reuse the compiled function and print nothing.
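A minimal sketch of that trace-time behavior, using only the public jax.jit API. The function name square_sum and the global trace_count counter are hypothetical, added for illustration. The counter is incremented by ordinary Python code, so, like print(), it only runs while JAX traces the function: once per new input shape and dtype.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter; plain Python, so it only changes during tracing


@jax.jit
def square_sum(x):
    global trace_count
    trace_count += 1         # runs at trace time, exactly like a plain print()
    print("traced with", x)  # shows an abstract tracer, not the real numbers
    return jnp.sum(x ** 2)


square_sum(jnp.array([1.0, 2.0, 3.0]))  # first call: traces and prints a tracer
square_sum(jnp.array([4.0, 5.0, 6.0]))  # same shape/dtype: cached, prints nothing
square_sum(jnp.array([1.0, 2.0]))       # new shape: retraces, prints a tracer again
print("trace_count =", trace_count)
```

Three calls, but only two traces, so the plain print() appears twice and never shows runtime values.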
jax.debug.print(fmt, *args, **kwargs) solves this. It is built on a JAX
debug-callback primitive that survives compilation, so it fires at every
call and prints the real runtime values. This is the standard tool for
inspecting intermediate values inside jitted functions during training.
Worked mini-example
import jax
import jax.numpy as jnp

@jax.jit
def compute(x):
    y = x ** 2
    jax.debug.print("y = {y}", y=y)  # prints at every call
    return jnp.sum(y)

compute(jnp.array([1.0, 2.0, 3.0]))
# stdout: y = [1. 4. 9.]
compute(jnp.array([4.0, 5.0]))
# stdout: y = [16. 25.]
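One way to check that the print really fires on every call, including calls that reuse the cached compilation, is to capture it. This sketch assumes jax.debug.print writes through Python's sys.stdout (as recent JAX releases do), so contextlib.redirect_stdout can collect it; jax.effects_barrier() is used to wait for any pending debug callbacks before the capture ends.

```python
import contextlib
import io

import jax
import jax.numpy as jnp


@jax.jit
def compute(x):
    y = x ** 2
    jax.debug.print("y = {y}", y=y)  # staged into the compiled computation
    return jnp.sum(y)


buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    # Same shape and dtype twice: the second call reuses the compiled function.
    compute(jnp.array([1.0, 2.0, 3.0])).block_until_ready()
    compute(jnp.array([1.0, 2.0, 3.0])).block_until_ready()
    jax.effects_barrier()  # wait until all debug callbacks have run

lines = buf.getvalue().splitlines()
print(lines)
```

Both calls hit the same compiled function, yet each one prints, which is exactly what a plain print() cannot do.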
Common pitfalls
- Don't use print() inside jit: it fires only at trace time with abstract tracers, not real values. Use jax.debug.print instead.
- The format string uses {name} placeholders, and the keyword arguments must match the placeholder names exactly.
- Output is written to stdout (as in the worked example above) by a host callback, so it appears when the computation actually runs.
- Use sparingly: every jax.debug.print adds a host callback that copies values from the device to the host at runtime, which slows execution. It is meant for debugging, not production.
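To keep the callback out of production runs without deleting it, one pattern (an illustration, not something this page prescribes) is to guard the print with a static Python flag. The function name scaled_sum and its debug argument are hypothetical. Because debug is marked static, the if statement is resolved at trace time, and the debug=False version compiles with no callback at all.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="debug")
def scaled_sum(x, debug=False):
    s = jnp.sum(x) * 2.0
    if debug:  # plain Python branch, decided at trace time since debug is static
        jax.debug.print("scaled_sum: {s}", s=s)
    return s


x = jnp.array([1.0, 2.0])
scaled_sum(x)              # compiled without any host callback
scaled_sum(x, debug=True)  # separate compilation that includes the print
```

The trade-off is one extra compilation per distinct value of debug.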
Problem
Implement debug_printed_compute(x) that:
- Computes s = jnp.sum(x).
- Calls jax.debug.print("sum: {s}", s=s) to log the sum.
- Returns s.

x: a 1-D JAX array.
Returns: a scalar equal to sum(x).
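A minimal sketch of one possible solution, not the site's reference answer. Wrapping the function in jax.jit is an assumption; the problem does not require it, and jax.debug.print prints the runtime sum with or without jit.

```python
import jax
import jax.numpy as jnp


@jax.jit  # assumption: optional, the print fires on every call either way
def debug_printed_compute(x):
    s = jnp.sum(x)                    # scalar sum of the 1-D input
    jax.debug.print("sum: {s}", s=s)  # placeholder name matches the keyword
    return s


debug_printed_compute(jnp.array([1.0, 2.0, 3.0]))
```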