medium primitives

jax.debug.print with Multiple kwargs

Why this matters

A single jax.debug.print call can log several named values at once: pass each value as a keyword argument and refer to it by name in the format string. Compared with separate calls, this issues one host callback instead of several and keeps related values on the same line of debug output. In a training loop, for example, you often want to inspect a loss value and its gradient norm together; one call is cleaner than two.

Worked mini-example

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    loss = jnp.sum(x ** 2)
    grad_norm = jnp.sqrt(jnp.sum((2 * x) ** 2))
    jax.debug.print("loss={loss} grad_norm={grad_norm}",
                    loss=loss, grad_norm=grad_norm)
    return loss

step(jnp.array([1.0, 2.0, 3.0]))
# stdout: loss=14.0 grad_norm=7.483314...

The format string follows Python's str.format rules: each {name} placeholder is filled by the keyword argument of the same name. Keyword arguments with no matching placeholder are accepted but left out of the output, while a placeholder with no matching keyword argument raises a KeyError.

Common pitfalls

  • Placeholder and kwarg names must match exactly: a typo in either the placeholder or the kwarg name leaves a placeholder unfilled and raises a KeyError.
  • debug.print writes to stdout at runtime, and some test runners capture stdout and hide it; the test for this problem checks only the return value.
  • Don’t use plain print: inside jax.jit, print runs at trace time and shows an abstract tracer object, not the real value.
  • Side-effect cost: each debug.print call adds a host callback that copies its arguments from device to host on every execution, which can slow down a hot loop. Remove or disable these calls in production code.
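The plain-print pitfall and the production-cost pitfall can be seen together in a small sketch. The names with_plain_print, train_step, trace_log and the debug flag are illustrative, not part of JAX. Making the flag a static argument is one assumed way to keep debug output out of production: the if is ordinary Python evaluated during tracing, so the debug=False variant is compiled without any callback.

```python
import functools

import jax
import jax.numpy as jnp

trace_log = []  # illustrative: records what plain Python code saw during tracing


@jax.jit
def with_plain_print(x):
    # Runs only while JAX traces the function: `x` is an abstract tracer here.
    print("plain print sees:", x)
    trace_log.append(repr(x))
    # Runs every time the compiled function executes, with the concrete value.
    jax.debug.print("debug print sees: {x}", x=x)
    return x * 2


x = jnp.array([1.0, 2.0, 3.0])
with_plain_print(x)  # first call traces, so both prints fire
with_plain_print(x)  # same shape and dtype: cached, only the debug print fires


# `debug` is static, so each value gets its own compiled version of the function.
@functools.partial(jax.jit, static_argnames="debug")
def train_step(x, debug=False):
    loss = jnp.sum(x ** 2)
    if debug:  # plain Python `if`, resolved at trace time, not on device
        jax.debug.print("loss={loss}", loss=loss)
    return loss


train_step(x, debug=True)  # compiled with the debug callback
train_step(x)              # compiled without it
```

The trade-off of a static flag is one extra compilation per flag value, which is usually negligible next to the cost of a callback on every step.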

Problem

Implement debug_print_two_things(x) that:

  1. Computes s = jnp.sum(x) and m = jnp.mean(x).
  2. Calls jax.debug.print("sum={s} mean={m}", s=s, m=m) to log both values.
  3. Returns s.
  • x: 1-D JAX array.

Returns: scalar, equal to sum(x).

Hints

jax debug-print kwargs
