io_callback for Side Effects

Why this matters

jax.pure_callback requires a pure function: JAX may skip a pure_callback whose result is unused, or call it more than once, so side effects inside it are unreliable. But sometimes you genuinely need side effects inside a JAX computation: logging a loss value to a file, writing metrics to TensorBoard, updating a training progress bar, or recording a step counter. These are inherently impure.

jax.experimental.io_callback(callback, result_shape_dtypes, *args, ordered=False, ...) fills this gap. Unlike pure_callback, it explicitly marks the callback as impure (it has side effects), so JAX keeps the call even when its result is unused. The ordered flag defaults to False; passing ordered=True guarantees that ordered callbacks run in the order they appear in the program, which is critical when the order of log writes matters.
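A minimal sketch of the ordering guarantee, assuming a recent JAX release: two ordered callbacks inside one jitted function append to a host list, and the list comes out in program order. The names record, events and two_writes are hypothetical. functools.partial binds a label that is not a JAX value, and jax.effects_barrier() waits for pending callback side effects before the list is read.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

events = []  # hypothetical host-side record of callback invocations


def record(tag, value):
    # Runs on the host; `value` arrives as a 0-d NumPy array.
    events.append((tag, float(value)))
    # Return the value unchanged so it matches the declared result shape.
    return value


@jax.jit
def two_writes(x):
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # Both callbacks are ordered, so "first" runs before "second".
    a = io_callback(functools.partial(record, "first"), spec, x, ordered=True)
    b = io_callback(functools.partial(record, "second"), spec, x + 1.0, ordered=True)
    return a + b


out = two_writes(jnp.float32(1.0))
jax.effects_barrier()  # wait until all pending callback side effects are done
print(events)  # [('first', 1.0), ('second', 2.0)]
```

With ordered=False the two appends would still happen, but JAX would make no promise about which one lands first.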

Worked mini-example

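import jax
import jax.numpy as jnp
from jax.experimental import io_callback

step_log = []  # host-side list; the callback appends to it

def log_loss(loss_val):
    # Runs on the host and receives a NumPy array.
    step_log.append(float(loss_val))
    return loss_val  # pass-through: must match result_shape

@jax.jit
def training_step(params, x):
    loss = jnp.sum((params * x) ** 2)
    result_shape = jax.ShapeDtypeStruct(loss.shape, loss.dtype)
    logged_loss = io_callback(log_loss, result_shape, loss, ordered=True)
    return logged_loss

training_step(jnp.ones(3), jnp.arange(3.0)).block_until_ready()
print(step_log)  # [5.0]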

Common pitfalls

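  • ordered=True is important for logging: without it, JAX makes no promise about the relative order of callbacks and may reorder them to exploit parallelism. For file writes and sequential logs, always set ordered=True. Note that ordered callbacks cannot be used under jax.vmap.
  • Not differentiable: io_callback has no differentiation rule, so calling it on a value that jax.grad is differentiating raises an error. Use it only for side effects, and log values outside the differentiated function (for example, the loss returned by jax.value_and_grad).
  • Do not use in hot loops: every callback copies its arguments from device to host and runs Python while the device computation waits. Use it sparingly (e.g., every 100 steps).
  • Unlike pure_callback, the callback CAN have side effects; that's the whole point. But keep it fast, since the JAX computation stalls while it runs.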
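The differentiation and hot-loop pitfalls can be handled together. The sketch below assumes a toy quadratic loss, a fixed learning rate, and a hypothetical LOG_EVERY interval (2 here so the effect is visible; something like 100 in practice). It computes the loss with jax.value_and_grad, so the value handed to io_callback is not being differentiated, and wraps the callback in lax.cond so the host round-trip only happens on logging steps. record_loss, loss_fn, train_step and loss_history are illustrative names, not part of any API.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import io_callback

LEARNING_RATE = 0.1  # illustrative value
LOG_EVERY = 2        # hypothetical logging interval

loss_history = []  # host-side (step, loss) pairs


def record_loss(step, loss):
    # Host-side side effect; both arguments arrive as 0-d NumPy arrays.
    loss_history.append((int(step), float(loss)))
    return loss  # pass-through to match the declared result shape


def loss_fn(params, x):
    return jnp.sum((params * x) ** 2)


@jax.jit
def train_step(params, x, step):
    # value_and_grad returns the loss as a plain value, outside the
    # differentiated computation, so it is safe to pass to io_callback.
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    new_params = params - LEARNING_RATE * grads

    def log(operands):
        step, loss = operands
        spec = jax.ShapeDtypeStruct(loss.shape, loss.dtype)
        # The callback's output is discarded; only the append matters.
        io_callback(record_loss, spec, step, loss, ordered=True)

    def skip(operands):
        return None

    # Only logging steps pay for the device-to-host round-trip.
    lax.cond(step % LOG_EVERY == 0, log, skip, (step, loss))
    return new_params


params = jnp.ones(3)
x = jnp.arange(3.0)
for step in range(5):
    params = train_step(params, x, step)
jax.effects_barrier()  # make sure every pending callback has run
print(loss_history)  # entries for steps 0, 2 and 4
```

Both branches of lax.cond must return the same structure; here both return None, which is why the callback's pass-through result is simply dropped.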

Problem

Implement io_callback_passthrough(x) that uses jax.experimental.io_callback with ordered=True to call a host function that returns 2 * x. This exercises the io_callback API β€” in real code, the host function would log metrics or write to a file.

  • x: 1-D jax array.

Returns: 1-D jax array β€” 2 * x (computed via ordered io_callback).
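One possible solution sketch, assuming only the API used in the mini-example: declare the output spec with jax.ShapeDtypeStruct from x's shape and dtype, and let the host function do the doubling. _host_double is a hypothetical helper name; in real code its body is where the log or file write would go.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def _host_double(x):
    # Runs on the host with a NumPy array; keep the input dtype so the
    # result matches the declared ShapeDtypeStruct exactly.
    return np.asarray(2 * x, dtype=x.dtype)


def io_callback_passthrough(x):
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return io_callback(_host_double, result_shape, x, ordered=True)
```

The same function works eagerly and under jax.jit, since the output shape and dtype are fixed by x.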

Hints

jax io-callback interop
