medium
primitives
jax.debug.callback
Why this matters
Sometimes you need to call arbitrary Python from inside a JAX trace — not to return a value, but for a side effect: writing a log line, updating a dashboard, or calling an external API.
jax.debug.callback(fn, *args) lets you do exactly that. Unlike jax.pure_callback, it does not require fn to be pure, so fn may have side effects. It also returns no value into the JAX computation; the call runs purely for its effect.
Worked mini-example
```python
import jax
import jax.numpy as jnp

step = 0

def log_step(arr):
    # Side-effecting host function: mutates a global and prints.
    global step
    step += 1
    print(f"step {step}: mean = {arr.mean()}")

@jax.jit
def train_step(x):
    y = x ** 2
    jax.debug.callback(log_step, y)  # fires at runtime, not at trace time
    return jnp.sum(y)

train_step(jnp.array([1.0, 2.0, 3.0]))
# prints: step 1: mean = 4.6666665
```
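The "runtime, not trace time" distinction can be checked directly. The sketch below uses hypothetical names (f, record, trace_log, runtime_log) and calls the same jitted function three times, counting how often an ordinary Python statement runs compared with a jax.debug.callback. It assumes jax.effects_barrier() is available to wait for callbacks that are still pending, as it is in recent JAX releases.

```python
import jax
import jax.numpy as jnp

trace_log = []    # hypothetical: appended by ordinary Python during tracing
runtime_log = []  # hypothetical: appended by the debug callback on each run

def record(arr):
    # Host-side callback: runs every time the compiled function executes.
    runtime_log.append(float(arr.sum()))

@jax.jit
def f(x):
    trace_log.append("traced")     # plain Python side effect: runs only while tracing
    jax.debug.callback(record, x)  # staged into the compiled program
    return x * 2

x = jnp.arange(3.0)
for _ in range(3):
    f(x).block_until_ready()
jax.effects_barrier()  # wait for any callbacks still in flight

print(len(trace_log))  # 1: traced once, then the cached executable is reused
print(runtime_log)     # [3.0, 3.0, 3.0]: one callback per call
```

A plain print inside f would behave like trace_log and fire only once, which is why the worked example routes logging through jax.debug.callback.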
Common pitfalls
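- Expecting a return value: debug.callback returns None to JAX. Use jax.pure_callback when you need the Python function's output back in the trace (or jax.experimental.io_callback if that function also has side effects).
- Callback order: by default (ordered=False), JAX does not guarantee the order in which callbacks run, even relative to one another. Pass ordered=True to keep ordered callbacks in program order; even then, don't rely on their timing relative to other XLA ops.
- Not differentiable: because debug.callback feeds nothing back into the computation, no gradient flows through it. It can still be used inside a function you differentiate; it simply contributes nothing to the gradient.
- Performance: each callback copies its arguments from device to host and calls back into Python, which can stall the device. Use sparingly in hot paths.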
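A minimal sketch of the first two pitfalls, assuming a single-device setup. The names tag, host_square, g and events are illustrative, not part of any API. jax.debug.callback hands back None, while jax.pure_callback returns an array whose shape and dtype must be declared up front; ordered=True keeps the two debug callbacks in program order.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

events = []  # hypothetical host-side record of callback activity

def tag(name, arr):
    # Side-effecting host function used with jax.debug.callback.
    events.append((name, float(arr)))

def host_square(arr):
    # Pure host function: its output depends only on its input.
    return np.square(np.asarray(arr))

@jax.jit
def g(x):
    # debug.callback hands nothing back to JAX; its Python result is None.
    out = jax.debug.callback(functools.partial(tag, "before"), x.sum(), ordered=True)
    assert out is None  # checked once, at trace time
    # pure_callback does return a value, so the result shape/dtype is declared.
    y = jax.pure_callback(host_square, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    # ordered=True keeps this callback after the "before" one.
    jax.debug.callback(functools.partial(tag, "after"), y.sum(), ordered=True)
    return y

y = g(jnp.array([1.0, 2.0, 3.0]))
jax.effects_barrier()  # make sure pending callbacks have run
print(events)  # [('before', 6.0), ('after', 14.0)]
print(y)       # [1. 4. 9.]
```

The label is bound with functools.partial rather than passed as a callback argument, because the positional arguments of jax.debug.callback are expected to be arrays.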
Problem
Implement debug_callback_passthrough(x) that:
- Defines a no-op callback cb(arr): pass.
- Calls jax.debug.callback(cb, x), demonstrating the pattern.
- Returns x + 1.

- x: 1-D JAX array.
Returns: 1-D array, x + 1.
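One possible implementation, offered as a sketch rather than the site's reference solution. The callback does nothing, and because jax.debug.callback returns None, the result depends only on x + 1. The same function also works unchanged under jax.jit.

```python
import jax
import jax.numpy as jnp

def debug_callback_passthrough(x):
    # No-op host callback: it exists only to demonstrate the pattern.
    def cb(arr):
        pass

    # Stage the callback; it returns None and leaves x untouched.
    jax.debug.callback(cb, x)
    return x + 1

out = debug_callback_passthrough(jnp.array([1.0, 2.0, 3.0]))
print(out)  # [2. 3. 4.]

# Under jit the callback is staged into the compiled program instead.
out_jit = jax.jit(debug_callback_passthrough)(jnp.array([1.0, 2.0, 3.0]))
```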
Tags: jax, debug-callback, debugging