medium primitives

jax.debug.callback

Why this matters

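Sometimes you need to run arbitrary Python from inside a jitted JAX function, not to compute a value but for a side effect: writing a log line, updating a dashboard, or calling an external API.

jax.debug.callback(fn, *args) lets you do exactly that: each time the compiled computation runs, fn is called on the host with the concrete values of args. Unlike jax.pure_callback, fn is not required to be pure and may have side effects. It returns nothing into the JAX computation (the call itself evaluates to None) and runs purely for effect. As the name suggests, it is intended mainly for debugging; for side effects the program actually depends on, such as I/O, JAX also provides jax.experimental.io_callback.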
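A minimal sketch of the difference between code that runs at trace time and code that runs at execution time. The names f, record, trace_log, and seen are illustrative only; jax.effects_barrier() is used just to wait for pending callbacks before inspecting the results.

```python
import jax
import jax.numpy as jnp

trace_log = []  # appended by plain Python: only while jit is tracing f
seen = []       # appended by the callback: once per execution of compiled f

def record(total):
    # Called on the host with a concrete value, not a tracer.
    seen.append(float(total))

@jax.jit
def f(x):
    trace_log.append("traced")
    result = jax.debug.callback(record, jnp.sum(x))
    assert result is None  # nothing comes back into the computation
    return x * 2

f(jnp.array([1.0, 2.0]))
f(jnp.array([3.0, 4.0]))  # same shape and dtype: reuses the compiled version
jax.effects_barrier()      # wait until pending callbacks have run

print(trace_log)     # ['traced']: f was traced once
print(sorted(seen))  # [3.0, 7.0]: the callback ran on both calls
```

The plain Python append fires once because the second call hits the jit cache, while the callback fires on every execution.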

Worked mini-example

import jax
import jax.numpy as jnp

step = 0

def log_step(arr):
    global step
    step += 1
    print(f"step {step}: mean = {arr.mean()}")

@jax.jit
def train_step(x):
    y = x ** 2
    jax.debug.callback(log_step, y)   # runs each time the compiled function executes, not at trace time
    return jnp.sum(y)

train_step(jnp.array([1.0, 2.0, 3.0]))
# prints: step 1: mean = 4.6666665   (14/3, computed in float32)

Common pitfalls

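  • Expecting a return value: debug.callback evaluates to None inside the traced function. Use jax.pure_callback when you need the Python function’s output back in the computation, or jax.experimental.io_callback if that function is also impure.
  • Callback order: with the default ordered=False, JAX does not guarantee that callbacks run in program order. Pass ordered=True to run them in the order they appear relative to other ordered callbacks, and don't rely on their timing relative to the surrounding XLA ops.
  • Gradients: because debug.callback returns nothing, no gradient flows through it. A function that contains one can still be differentiated; the callback simply receives primal values.
  • Performance: each callback copies its arguments from device to host and runs Python there, adding transfer and interpreter overhead to every execution. Use it sparingly in hot paths.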
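A short sketch of the first three pitfalls, assuming a single-device setup. host_square, square_on_host, log_in_order, and loss are hypothetical names chosen for illustration; jax.pure_callback, the ordered= keyword of jax.debug.callback, and jax.effects_barrier are existing JAX APIs.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Pitfall 1: need the result back? Use jax.pure_callback and declare its shape/dtype.
def host_square(a):
    return np.asarray(a) ** 2  # must match the declared shape and dtype

@jax.jit
def square_on_host(x):
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_square, out_spec, x)

# Pitfall 2: need program order? Pass ordered=True to each callback.
events = []

@jax.jit
def log_in_order(x):
    jax.debug.callback(lambda v: events.append(("first", float(v))), x[0], ordered=True)
    jax.debug.callback(lambda v: events.append(("second", float(v))), x[1], ordered=True)
    return x

# Pitfall 3: differentiating a function that logs. The callback only sees primals.
logged = []

def loss(x):
    jax.debug.callback(lambda s: logged.append(float(s)), jnp.sum(x))
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0])
squares = square_on_host(x)  # [1., 4.]
log_in_order(x)
grads = jax.grad(loss)(x)    # d/dx sum(x**2) = 2 * x = [2., 4.]
jax.effects_barrier()        # make sure every callback has finished
```

Note that pure_callback needs the output spec up front because the compiled program must know the result's shape and dtype before the host function ever runs.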

Problem

Implement debug_callback_passthrough(x) that:

  1. Defines a no-op callback cb(arr): pass.
  2. Calls jax.debug.callback(cb, x) — demonstrating the pattern.
  3. Returns x + 1.
Input: x, a 1-D JAX array.

Returns: a 1-D array equal to x + 1.
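One possible implementation that follows the three steps literally; it is a sketch, not the site's reference solution.

```python
import jax
import jax.numpy as jnp

def debug_callback_passthrough(x):
    # Step 1: a no-op callback; it exists only to show the pattern.
    def cb(arr):
        pass

    # Step 2: register the side effect. The call evaluates to None, so its
    # result is not used and x flows through unchanged.
    jax.debug.callback(cb, x)

    # Step 3: the returned value comes from ordinary JAX ops.
    return x + 1

print(debug_callback_passthrough(jnp.array([1.0, 2.0, 3.0])))           # [2. 3. 4.]
print(jax.jit(debug_callback_passthrough)(jnp.array([1.0, 2.0, 3.0])))  # [2. 3. 4.]
```

The same function works eagerly and under jax.jit, since debug.callback does not alter the data path.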

