
jax.debug.callback for Side Effects

Why this matters

jax.debug.callback(fn, *args) lets you call an arbitrary Python function from inside a JAX trace for its side effects only. The callback receives concrete array values at runtime (not abstract tracers) and can write logs, update dashboards, call external APIs, or accumulate training metrics — all from within a jax.jit-compiled function.

Worked mini-example

import jax
import jax.numpy as jnp

step_count = 0

def log_norm(arr):
    global step_count
    step_count += 1
    print(f"step {step_count}: L2={float(jnp.linalg.norm(arr)):.4f}")

@jax.jit
def train_step(x):
    y = x ** 2
    jax.debug.callback(log_norm, y)  # fires at runtime
    return jnp.sum(y)

train_step(jnp.array([1.0, 2.0, 3.0]))
# prints: step 1: L2=9.8995   (y = [1, 4, 9], so L2 = sqrt(98))
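The claim that the callback sees concrete values at runtime, not tracers, is easiest to check by contrasting it with a plain print inside the same jitted function. The sketch below uses hypothetical names (f, record, calls). A plain print runs only once, while JAX traces, and sees an abstract tracer. jax.debug.callback fires on every execution and receives concrete values. jax.effects_barrier() is used to wait for pending callbacks before the recorded values are inspected.

```python
import jax
import jax.numpy as jnp

calls = []  # hypothetical sink: one entry per runtime execution of the callback

def record(arr):
    # Runs on the host with a concrete 0-d array, so float() is safe here.
    calls.append(float(arr))

@jax.jit
def f(x):
    # Plain print: runs only while tracing and shows an abstract tracer.
    print("trace-time print:", x)
    # debug.callback: runs every time the compiled function executes.
    jax.debug.callback(record, x.sum())
    return x * 2

f(jnp.array([1.0, 2.0]))  # traces (prints once), then executes
f(jnp.array([3.0, 4.0]))  # same shape/dtype: cached, no retrace, no print
jax.effects_barrier()     # wait for pending callbacks before reading `calls`
print(sorted(calls))      # [3.0, 7.0]
```

The trace-time print appears once, but the callback records one value per call. That is why debug.callback, rather than print, suits logging from compiled code.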

Common pitfalls

  • No return value into JAX: debug.callback returns None to the JAX trace. Use jax.pure_callback when you need the Python function’s output back inside the computation.
  • Not a gradient path: debug.callback can be called inside a function you pass to jax.grad, where the callback observes primal values. Because it returns nothing, though, it cannot feed any value into what you differentiate.
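  • Different from io_callback: debug.callback is unordered by default (pass ordered=True to run callbacks in program order). jax.experimental.io_callback is intended for general impure I/O, can return values, and accepts the same ordered flag.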
  • Performance: each callback copies its arguments to the host and runs Python code in the middle of the computation. This adds latency and can stall the device, so avoid it in tight inner loops.
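The first three pitfalls can be illustrated together. This is a minimal sketch assuming a recent JAX release in which jax.pure_callback accepts (callback, result_shape_dtypes, *args) and jax.debug.callback accepts ordered=True. The helper names (host_cumsum, with_pure, square_with_log, two_logs) are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Pitfall 1: to get a value back, use jax.pure_callback, not jax.debug.callback.
def host_cumsum(arr):
    # Runs on the host with a concrete array; the result must match the
    # shape/dtype declared via ShapeDtypeStruct below.
    return np.cumsum(np.asarray(arr)).astype(arr.dtype)

@jax.jit
def with_pure(x):
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_cumsum, out_spec, x)

print(with_pure(jnp.array([1.0, 2.0, 3.0])))  # [1. 3. 6.]

# Pitfall 2: a debug.callback inside a differentiated function only observes
# values; it contributes nothing to the gradient.
seen = []

def square_with_log(x):
    jax.debug.callback(lambda v: seen.append(float(v)), x)  # side effect only
    return x ** 2

print(jax.grad(square_with_log)(3.0))  # 6.0

# Pitfall 3: ordered=True makes the callbacks run in program order.
order = []

@jax.jit
def two_logs(x):
    jax.debug.callback(lambda v: order.append(("a", float(v))), x, ordered=True)
    jax.debug.callback(lambda v: order.append(("b", float(v))), x + 1, ordered=True)
    return x

two_logs(1.0)
jax.effects_barrier()  # wait for pending callbacks before reading `order`
print(order)  # [('a', 1.0), ('b', 2.0)]
```

The gradient of x ** 2 at 3.0 is unaffected by the logging call. Only pure_callback's declared output flows back into the traced computation.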

Problem

Implement callback_passthrough(x) that:

  1. Defines a no-op callback cb(arr): pass.
  2. Calls jax.debug.callback(cb, x) — demonstrating the side-effect pattern.
  3. Returns x + 1.

Arguments:
  • x: 1-D JAX array.

Returns: 1-D array equal to x + 1.
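One way to meet this specification is sketched below. It is a sketch, not the site's reference solution. The test inputs are arbitrary, and the function should behave the same eagerly and under jax.jit.

```python
import jax
import jax.numpy as jnp

def callback_passthrough(x):
    # No-op callback: it would receive a concrete array on the host and ignore it.
    def cb(arr):
        pass

    # Schedule the side effect; debug.callback returns None, so nothing is captured.
    jax.debug.callback(cb, x)
    # The result is computed independently of the callback.
    return x + 1

print(callback_passthrough(jnp.array([1.0, 2.0])))          # [2. 3.]
print(jax.jit(callback_passthrough)(jnp.array([0, 1, 2])))  # [1 2 3]
```

Because the callback's result is discarded, swapping cb for a real logger would change only the side effect, never the returned array.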

Hints

jax debug-callback side-effect
