medium
primitives
jax.debug.callback for Side Effects
Why this matters
jax.debug.callback(fn, *args) lets you call an arbitrary Python function
from inside a JAX trace for its side effects only. The callback receives
concrete array values at runtime (not abstract tracers) and can write logs,
update dashboards, call external APIs, or accumulate training metrics — all
from within a jax.jit-compiled function.
Worked mini-example
import jax
import jax.numpy as jnp

step_count = 0

def log_norm(arr):
    # Runs on the host with a concrete array, once per execution.
    global step_count
    step_count += 1
    print(f"step {step_count}: L2={float(jnp.linalg.norm(arr)):.4f}")

@jax.jit
def train_step(x):
    y = x ** 2
    jax.debug.callback(log_norm, y)  # fires at runtime, not at trace time
    return jnp.sum(y)

train_step(jnp.array([1.0, 2.0, 3.0]))
# prints: step 1: L2=9.8995   (y = [1, 4, 9], so ||y|| = sqrt(98))
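To make the trace-time versus run-time distinction concrete, here is a minimal sketch (the names record, f, trace_log and runtime_log are illustrative, not part of the problem): a plain Python side effect inside a jax.jit function runs once, while tracing, whereas jax.debug.callback fires on every execution and sees concrete values. jax.effects_barrier() at the end waits for any callbacks still pending.

```python
import jax
import jax.numpy as jnp

trace_log = []    # filled by ordinary Python code inside the jitted function
runtime_log = []  # filled by the debug callback

def record(arr):
    # Executed on the host at run time with a concrete array.
    runtime_log.append(float(arr.sum()))

@jax.jit
def f(x):
    trace_log.append(type(x).__name__)  # runs once, while JAX traces f
    jax.debug.callback(record, x)       # runs on every call of the compiled f
    return x * 2

for v in (1.0, 2.0, 3.0):
    # Same shape and dtype each time, so f is traced and compiled only once.
    f(jnp.array([v, v])).block_until_ready()
jax.effects_barrier()  # make sure all pending callbacks have finished

print(len(trace_log))  # 1
print(runtime_log)     # [2.0, 4.0, 6.0]
```

The single trace_log entry holds the name of a tracer type, not a value, which is why logging real numbers from a compiled function needs a callback.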
Common pitfalls
- No return value into JAX: debug.callback returns None to the JAX trace. Use jax.pure_callback when you need the Python function's output back inside the computation.
- Not differentiable: no gradient flows through debug.callback. It can only observe values, so nothing it computes can feed into a value you intend to differentiate.
- Different from io_callback: debug.callback is unordered by default (ordered=True orders it relative to other ordered callbacks) and cannot return values; jax.experimental.io_callback can return values and, with ordered=True, enforces sequential I/O ordering.
- Performance: each callback copies its arguments to the host and introduces a host synchronisation point in the middle of the computation. Avoid it in tight inner loops.
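A minimal sketch of the first pitfall, assuming a recent JAX release that provides jax.pure_callback and jax.ShapeDtypeStruct (the helpers watch, host_clip and g are hypothetical): debug.callback hands nothing back to the trace, while pure_callback requires the result's shape and dtype to be declared up front and returns a traced value that later operations can use.

```python
import jax
import jax.numpy as jnp
import numpy as np

def watch(arr):
    # Side effect only; returns None.
    pass

def host_clip(arr):
    # Plain NumPy on the host. The result must match the shape and dtype
    # declared to pure_callback below.
    return np.clip(np.asarray(arr), 0.0, 1.0).astype(np.float32)

@jax.jit
def g(x):
    # debug.callback evaluates to None: nothing flows back into the trace.
    result = jax.debug.callback(watch, x)
    assert result is None  # checked once, at trace time
    # pure_callback needs the output shape/dtype so JAX can keep tracing.
    out_spec = jax.ShapeDtypeStruct(x.shape, jnp.float32)
    clipped = jax.pure_callback(host_clip, out_spec, x)
    return clipped + 1.0

x = jnp.array([-0.5, 0.25, 2.0], dtype=jnp.float32)
out = g(x)  # values: [1.0, 1.25, 2.0]
```

Note that pure_callback treats the host function as pure, so JAX may skip the call entirely (for example, if its result is unused); side effects belong in debug.callback or io_callback.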
Problem
Implement callback_passthrough(x) that:
- Defines a no-op callback, def cb(arr): pass.
- Calls jax.debug.callback(cb, x), demonstrating the side-effect pattern.
- Returns x + 1.
Input: x, a 1-D JAX array.
Returns: a 1-D array, x + 1.
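One possible solution, sketched under the assumption that the checker calls the function either eagerly or wrapped in jax.jit (both work, because jax.debug.callback is valid in either context):

```python
import jax
import jax.numpy as jnp

def callback_passthrough(x):
    def cb(arr):
        # No-op: marks where a real side effect (logging, metrics) would go.
        pass

    jax.debug.callback(cb, x)  # evaluates to None; nothing flows back
    return x + 1

eager = callback_passthrough(jnp.array([1.0, 2.0, 3.0]))
jitted = jax.jit(callback_passthrough)(jnp.array([0.0, -1.0]))
```

Under jax.jit, cb still runs on every call; it simply does nothing, and the returned array is unaffected by it.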
Hints
jax
debug-callback
side-effect