jax.debug.print with Multiple kwargs
Why this matters
A single jax.debug.print call can log several named values at once if you
pass them as extra keyword arguments. This is cheaper than separate calls
(one host callback instead of several) and keeps related values on the same
line of your debug output. In a training loop you often want to inspect a
loss value and its gradient norm together; one call is cleaner than two.
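To make the training-loop case concrete, here is a minimal sketch: plain gradient descent on sum(x**2) inside jax.lax.fori_loop, with one jax.debug.print call per iteration that logs the step index, the loss and the gradient norm together. The loss function, the learning rate LR and the number of steps are illustrative assumptions, not part of the original text.

```python
import jax
import jax.numpy as jnp

LR = 0.1      # illustrative learning rate (assumption)
N_STEPS = 3   # illustrative number of steps (assumption)


def loss_fn(x):
    return jnp.sum(x ** 2)


@jax.jit
def train(x):
    def body(i, x):
        loss, g = jax.value_and_grad(loss_fn)(x)
        # One call per step: step index, loss and gradient norm share a line.
        jax.debug.print("step={i} loss={loss} grad_norm={gn}",
                        i=i, loss=loss, gn=jnp.linalg.norm(g))
        return x - LR * g

    return jax.lax.fori_loop(0, N_STEPS, body, x)


x_final = train(jnp.array([1.0, 2.0, 3.0]))
```

Each step multiplies x by 1 - 2 * LR = 0.8, so x_final is about [0.512, 1.024, 1.536]. If the relative order of prints across separate calls matters, jax.debug.print also accepts ordered=True.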
Worked mini-example
```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    loss = jnp.sum(x ** 2)
    # The gradient of sum(x**2) is 2*x; take its L2 norm.
    grad_norm = jnp.sqrt(jnp.sum((2 * x) ** 2))
    # One call logs both values on the same line.
    jax.debug.print("loss={loss} grad_norm={grad_norm}",
                    loss=loss, grad_norm=grad_norm)
    return loss

step(jnp.array([1.0, 2.0, 3.0]))
# stdout: loss=14.0 grad_norm=7.4833...
```
Each {name} placeholder in the format string is filled from the keyword
argument of the same name. Extra kwargs with no matching placeholder are
ignored; a placeholder with no matching kwarg raises a KeyError when the
function is called.
Common pitfalls
- Format string placeholders must match kwarg names exactly: a typo in either the placeholder or the kwarg name raises a KeyError.
- jax.debug.print writes to stdout at runtime, so its output may be hidden by test runners that capture stdout; the test here checks only the return value.
- Don't use plain print inside jax.jit: print fires at trace time with an abstract tracer object, not the real value.
- Side-effect cost: every jax.debug.print call introduces a host callback and synchronisation. Disable or remove these calls in production.
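The last two pitfalls can be observed directly. In the sketch below, a Python-side list (the hypothetical trace_log) records what plain Python code sees: it runs only while JAX traces the function, and it sees a tracer rather than numbers. A static debug flag, an assumed pattern rather than anything prescribed by the original, decides at trace time whether jax.debug.print is compiled in at all, so the production path carries no callback.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical Python-side log: plain Python code (including a plain print)
# runs only while JAX traces the function, not on every call.
trace_log = []


@functools.partial(jax.jit, static_argnames="debug")
def step(x, debug):
    trace_log.append(repr(x))        # x is an abstract tracer here
    print("plain print sees:", x)    # prints a tracer, once per trace
    loss = jnp.sum(x ** 2)
    grad_norm = jnp.sqrt(jnp.sum((2 * x) ** 2))
    if debug:  # `debug` is static, so this `if` is resolved at trace time
        jax.debug.print("loss={loss} grad_norm={grad_norm}",
                        loss=loss, grad_norm=grad_norm)
    return loss


x = jnp.array([1.0, 2.0, 3.0])
step(x, debug=True)   # traces, then prints the concrete values
step(x, debug=True)   # cache hit: debug.print fires again, no retrace
step(x, debug=False)  # new trace with no debug.print compiled in
```

After these three calls trace_log holds two entries (one per trace), while jax.debug.print produced output on both debug=True calls.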
Problem
Implement debug_print_two_things(x) that:
- Computes s = jnp.sum(x) and m = jnp.mean(x).
- Calls jax.debug.print("sum={s} mean={m}", s=s, m=m) to log both values.
- Returns s.

- x: 1-D JAX array.
Returns: scalar, sum(x).
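One possible implementation, written as a sketch that follows the three steps above literally. It does not depend on jax.jit: jax.debug.print works both eagerly and inside a jitted function, so the same function can be called either way.

```python
import jax
import jax.numpy as jnp


def debug_print_two_things(x):
    s = jnp.sum(x)
    m = jnp.mean(x)
    # One call logs both named values on the same line.
    jax.debug.print("sum={s} mean={m}", s=s, m=m)
    return s


debug_print_two_things(jnp.array([1.0, 2.0, 3.0]))  # prints: sum=6.0 mean=2.0
```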