io_callback for Side Effects
Why this matters
jax.pure_callback requires a pure function. But sometimes you
genuinely need side effects inside a JAX computation: logging a
loss value to a file, writing metrics to TensorBoard, updating a
training progress bar, or recording a step counter. These are
inherently impure.
jax.experimental.io_callback(fn, result_shape, *args, ordered=True)
is the answer. Unlike pure_callback, it explicitly declares that fn
is impure (has side effects). result_shape describes the shape and
dtype of the value fn returns, for example a jax.ShapeDtypeStruct.
Setting ordered=True (the default is ordered=False) guarantees that
ordered callbacks execute in program order, that is, in the order in
which they appear in the JAX computation. This is critical when the
order of log writes matters.
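To see what ordered=True buys, here is a minimal sketch with two ordered callbacks inside one jitted function. The record helper and the tag names are hypothetical, and passing None as the result shape assumes the callback returns nothing. jax.effects_barrier() waits for pending side effects before the host list is inspected.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

events = []  # host-side record of callback invocations

def record(tag):
    # Hypothetical helper: build a host function that logs (tag, value).
    def _cb(x):
        events.append((tag, float(x)))  # side effect only, returns None
    return _cb

@jax.jit
def f(x):
    # result_shape=None because the callbacks return no value.
    io_callback(record("first"), None, x, ordered=True)
    y = x * 2
    io_callback(record("second"), None, y, ordered=True)
    return y

out = f(jnp.float32(1.5))
jax.effects_barrier()  # wait until pending callbacks have run
print(events)  # [('first', 1.5), ('second', 3.0)]
```

With ordered=True the two appends are guaranteed to happen in the order the calls appear in f.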
Worked mini-example
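import jax
import jax.numpy as jnp
from jax.experimental import io_callback

step_log = []  # host-side list that the callback appends to

def log_loss(loss_val):
    # Runs on the host with a NumPy value; the append is the side effect.
    step_log.append(float(loss_val))
    return loss_val  # pass-through: hand the value back to JAX

def training_step(params, x):
    loss = jnp.sum((params * x) ** 2)
    # io_callback needs the shape and dtype of the callback's return value.
    result_shape = jax.ShapeDtypeStruct(loss.shape, loss.dtype)
    logged_loss = io_callback(log_loss, result_shape, loss, ordered=True)
    return logged_loss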
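A minimal usage sketch for the example above, assuming it is run under jax.jit with made-up inputs. It adds a hypothetical trace_count to show why io_callback is needed instead of a plain Python side effect: the append inside log_loss runs on every call of the compiled step, while the counter only increments while JAX traces the function.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

step_log = []
trace_count = 0  # hypothetical: counts how often Python traces training_step

def log_loss(loss_val):
    step_log.append(float(loss_val))  # runs on every call of the compiled step
    return loss_val

def training_step(params, x):
    global trace_count
    trace_count += 1  # plain Python side effect: runs only during tracing
    loss = jnp.sum((params * x) ** 2)
    result_shape = jax.ShapeDtypeStruct(loss.shape, loss.dtype)
    return io_callback(log_loss, result_shape, loss, ordered=True)

step = jax.jit(training_step)
params = jnp.array([1.0, 2.0])
for x in (jnp.array([3.0, 4.0]), jnp.array([1.0, 1.0])):
    loss = step(params, x)
    loss.block_until_ready()  # the output comes from the callback, so it has run

print(step_log)     # [73.0, 5.0]
print(trace_count)  # 1
```

Both inputs share a shape and dtype, so jax.jit traces once and reuses the compiled step; the callback still fires on each call.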
Common pitfalls
- ordered=True is important for logging: without it (the default is ordered=False), JAX may execute io_callbacks in any order, since it is free to exploit parallelism. For file writes and sequential logs, always set ordered=True.
- Not differentiable: io_callback is opaque to jax.grad, and differentiating through it raises an error. Use it only for side effects, not for values you need to differentiate through.
- Do not use in hot loops: every callback transfers data from the device to the host and runs Python code, which stalls the otherwise asynchronous computation. Use it sparingly (e.g., every 100 steps).
- Unlike pure_callback, fn can have side effects; that's the whole point. But keep it fast to avoid stalling JAX.
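One way to respect both the differentiability and the hot-loop pitfalls is sketched below, assuming a toy quadratic loss. The gradient is taken of a pure loss_fn, the loss value is logged with io_callback outside jax.value_and_grad, and a static log flag (a design choice of this sketch, not an io_callback feature) means the callback is only compiled into the variant of the step that runs every LOG_EVERY steps. LOG_EVERY, LR, log_metrics, loss_fn and train_step are all illustrative names.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

LOG_EVERY = 100  # hypothetical logging interval
LR = 0.01        # hypothetical learning rate
logged = []      # (step, loss) pairs recorded on the host

def log_metrics(step, loss):
    # Host-side side effect; returns nothing, so result_shape is None.
    logged.append((int(step), float(loss)))

def loss_fn(params, x):
    # Pure loss with no callback inside, so it is safe to differentiate.
    return jnp.sum((params * x) ** 2)

@partial(jax.jit, static_argnames="log")
def train_step(params, x, step, log=False):
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    if log:  # Python-level branch, resolved at trace time
        io_callback(log_metrics, None, step, loss, ordered=True)
    return params - LR * grads

params = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])
for step in range(300):
    params = train_step(params, x, step, log=(step % LOG_EVERY == 0))
jax.effects_barrier()  # make sure pending callbacks have run
print([s for s, _ in logged])  # [0, 100, 200]
```

Because log is static, jax.jit compiles two versions of train_step, one with the callback and one without, so the non-logging steps never pay for a host round-trip.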
Problem
Implement io_callback_passthrough(x) that uses
jax.experimental.io_callback with ordered=True to call a host
function that returns 2 * x. This exercises the io_callback API:
in real code, the host function would log metrics or write to a file.
- x: 1-D jax array.
- Returns: 1-D jax array equal to 2 * x (computed via an ordered io_callback).
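One possible solution sketch, assuming the host function simply doubles its input; the helper name _host_double is hypothetical, and a real host function would log or write to a file before returning. The result shape is taken from x, since doubling preserves both shape and dtype.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

def _host_double(x):
    # Runs on the host; cast back to the input dtype so the result
    # matches the declared ShapeDtypeStruct exactly.
    return np.asarray(2 * x, dtype=x.dtype)

def io_callback_passthrough(x):
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return io_callback(_host_double, result_shape, x, ordered=True)

print(io_callback_passthrough(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
```

The same function also works under jax.jit, because the result shape is derived from the traced input's shape and dtype.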