pure_callback Basics
Why this matters
JAX’s function-transformation model (jit, grad, vmap) requires
functions to be pure and expressed in JAX primitives. But sometimes
you need to call a Python/NumPy function that JAX doesn’t know
about — a special function from scipy, a custom C extension, or a
third-party library with stable numerical properties.
jax.pure_callback(fn, result_shape, *args) is the bridge. It lets
you call any pure Python function from inside a JAX computation —
even inside jit — by handing control back to the Python interpreter
(the “host”) for that one call. The key constraint: the function
must be pure (no side effects, deterministic output for a given
input).
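A minimal sketch of the "even inside jit" claim follows. It assumes a recent JAX install. The host function `host_sort_desc` is a hypothetical helper, and `np.sort` is only a stand-in for a host-only routine, since `jnp.sort` could do the same job natively. While `sort_desc` is traced, only the shape and dtype of `x` are known. When the compiled code runs, JAX hands the actual values to NumPy on the host.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sort_desc(a):
    # Runs on the host. np.asarray makes sure we work with a plain NumPy
    # array regardless of how JAX hands the buffer over.
    a = np.asarray(a)
    return np.sort(a)[::-1].copy()  # descending order, computed by NumPy


@jax.jit
def sort_desc(x):
    # During tracing only x.shape and x.dtype are known, so the callback's
    # output is described by a ShapeDtypeStruct rather than by a value.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sort_desc, out_spec, x)


print(sort_desc(jnp.array([3.0, 1.0, 2.0])))
```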
Worked mini-example
```python
import jax
import jax.numpy as jnp
from scipy.special import expit  # logistic sigmoid from SciPy (JAX also has jax.nn.sigmoid)

def scipy_sigmoid(x):
    # Tell JAX the output has the same shape and dtype as x, so it can trace
    # the function without running expit.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(expit, result_shape, x)

x = jnp.array([0.0, 1.0, -1.0])
print(scipy_sigmoid(x))  # ≈ [0.5, 0.731, 0.269]
```
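`result_shape` does not have to be a single `ShapeDtypeStruct`. It can be a pytree of them, such as a tuple, whose structure mirrors what the host function returns. The sketch below uses a hypothetical helper, `host_min_max`, that returns two scalars. Each returned value has to agree in shape and dtype with its matching entry in the pytree.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_min_max(a):
    # Host-side NumPy code returning a tuple. Each entry must match the
    # corresponding ShapeDtypeStruct in shape and dtype.
    a = np.asarray(a)
    return np.min(a), np.max(a)


@jax.jit
def min_max(x):
    scalar = jax.ShapeDtypeStruct((), x.dtype)
    # The result spec is a tuple because the callback returns a tuple.
    return jax.pure_callback(host_min_max, (scalar, scalar), x)


lo, hi = min_max(jnp.array([0.0, 1.0, -1.0]))
print(lo, hi)
```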
Common pitfalls
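- result_shape is mandatory. JAX needs it during tracing to know the output's shape and dtype before the host function ever runs. It is a jax.ShapeDtypeStruct (or a pytree of them) and must match what fn actually returns. Leaving it out is an error.
- fn must be pure. It must not write to global state, print, or return different outputs for the same input. JAX may call it more than once, or skip it entirely if its result is never used. For side effects, use jax.debug.callback or jax.experimental.io_callback instead.
- Host callbacks are slow. Each call copies its inputs from the device to the host, runs Python, and copies the result back, which interrupts JAX's asynchronous dispatch. Avoid them in hot loops.
- Not differentiable by default. jax.grad cannot see inside the callback and raises an error if asked to differentiate through it. To get gradients, wrap the callback in jax.custom_jvp or jax.custom_vjp and supply the derivative yourself.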
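The last pitfall can be worked around with `jax.custom_jvp`, as in the sketch below. Two assumptions apply. First, a NumPy expression `1 / (1 + exp(-a))` stands in for `scipy.special.expit` so the sketch needs no SciPy. Second, the names `host_sigmoid` and `sigmoid_cb` are hypothetical. The callback computes only the forward value. The derivative, sigmoid(x) * (1 - sigmoid(x)), is written in `jnp` so that JAX can differentiate and transpose it without looking inside the callback.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sigmoid(a):
    # NumPy stand-in for scipy.special.expit, executed on the host.
    a = np.asarray(a)
    return 1.0 / (1.0 + np.exp(-a))


@jax.custom_jvp
def sigmoid_cb(x):
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sigmoid, spec, x)


@sigmoid_cb.defjvp
def sigmoid_cb_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = sigmoid_cb(x)  # primal output, still computed by the host callback
    # Analytic derivative in jnp: sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
    return y, y * (1.0 - y) * x_dot


grad_fn = jax.grad(lambda x: sigmoid_cb(x).sum())
g = grad_fn(jnp.array([0.0, 2.0]))
print(g)
```

jax.custom_vjp would work as well. custom_jvp is used here because a single rule supports both forward- and reverse-mode differentiation.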
Problem
Implement pure_callback_doubler(x) that uses jax.pure_callback
to invoke a NumPy doubler (arr * 2) from inside JAX.
- x: 1-D jax array.
- Returns: 1-D jax array, 2 * x (computed via host callback).
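Below is one possible sketch that follows the pattern of the mini-example. The host helper `numpy_doubler` is a hypothetical name. Because `arr * 2` keeps the input dtype (float32 stays float32, int32 stays int32), a `ShapeDtypeStruct` copied from `x` describes the result correctly. The same function also works under `jax.jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_doubler(arr):
    # Plain NumPy on the host. np.asarray guards against receiving a jax.Array.
    arr = np.asarray(arr)
    return arr * 2


def pure_callback_doubler(x):
    # Output has the same shape and dtype as the input.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(numpy_doubler, out_spec, x)


x = jnp.arange(4.0)
print(pure_callback_doubler(x))
print(jax.jit(pure_callback_doubler)(x))
```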