easy primitives

jax.debug.print

Why this matters

A common frustration when learning JAX: you add a print() inside a jit-compiled function to inspect a value, and the print fires once, at trace time, showing an abstract tracer object instead of the actual numbers you wanted. Subsequent calls with the same input shapes and dtypes reuse the compiled code and print nothing.
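The trace-time behavior is easy to reproduce. The sketch below is illustrative only: the function name doubled and the trace_log list are hypothetical helpers introduced here to count how often the Python body actually runs.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: one entry per run of the Python body


@jax.jit
def doubled(x):
    y = x * 2
    print("inside jit:", y)             # runs only while JAX traces the function
    trace_log.append(type(y).__name__)  # record what kind of object y was
    return y


first = doubled(jnp.ones(3))   # first call: traces, so print() shows a tracer
second = doubled(jnp.ones(3))  # same shape and dtype: cached, body is not re-run

print("Python body ran", len(trace_log), "time(s)")
# prints: Python body ran 1 time(s)
```

Both calls return real arrays, but the Python-level print() and the list append happen only during the single trace.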

jax.debug.print(fmt, *args, **kwargs) solves this. The call is staged into the compiled computation, so it survives jit and fires on every call, printing the real runtime values. This is the correct tool for inspecting intermediate values inside jitted functions during training.

Worked mini-example

import jax
import jax.numpy as jnp

@jax.jit
def compute(x):
    y = x ** 2
    jax.debug.print("y = {y}", y=y)   # prints at every call
    return jnp.sum(y)

compute(jnp.array([1.0, 2.0, 3.0]))
# stdout: y = [1. 4. 9.]

compute(jnp.array([4.0, 5.0]))
# stdout: y = [16. 25.]

Common pitfalls

  • Don’t use print() inside jit — it fires only at trace time with abstract tracers, not real values. Use jax.debug.print instead.
  • Format string uses {name} placeholders, and each placeholder name must match a keyword argument exactly. Positional {} placeholders, filled from positional arguments in order, also work.
  • Output goes to stdout, like print(), so it can mix with the program's other standard output.
  • Use sparingly: every debug.print adds a host callback, which copies the printed values from the device to the host and can slow the compiled function. Useful for debugging, not production.
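The placeholder rules above can be sketched as follows. The function name scaled_sum and its arguments are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x, scale):
    total = jnp.sum(x) * scale
    # Named placeholders: each {name} needs a keyword argument of the same name.
    jax.debug.print("total = {total}, scale = {scale}", total=total, scale=scale)
    # Positional placeholders are filled from positional arguments, in order.
    jax.debug.print("total = {}, scale = {}", total, scale)
    # Pass a plain format string, not an f-string: an f-string would be
    # formatted during tracing and would embed the tracer, not the value.
    return total


result = scaled_sum(jnp.array([1.0, 2.0, 3.0]), 2.0)
```

Each call to scaled_sum emits both debug lines, since the prints are part of the compiled computation rather than the Python trace.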

Problem

Implement debug_printed_compute(x) that:

  1. Computes s = jnp.sum(x).
  2. Calls jax.debug.print("sum: {s}", s=s) to log the sum.
  3. Returns s.

Arguments:
  • x: 1-D JAX array.

Returns: scalar — sum(x).
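
One possible shape for a solution, following the three steps literally, is sketched below. Wrapping the function in jax.jit is an assumption made here to show that the print survives compilation; the problem statement does not require it.

```python
import jax
import jax.numpy as jnp


@jax.jit  # assumption: jit is optional here, the problem does not ask for it
def debug_printed_compute(x):
    s = jnp.sum(x)                    # step 1: compute the sum
    jax.debug.print("sum: {s}", s=s)  # step 2: log the runtime value
    return s                          # step 3: return the scalar


value = debug_printed_compute(jnp.array([1.0, 2.0, 3.0]))
```

The returned value is a 0-d JAX array; float(value) converts it to a Python float if needed.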

