
pure_callback Basics

Why this matters

JAX’s function-transformation model (jit, grad, vmap) requires functions to be pure and expressed in JAX primitives. But sometimes you need to call a Python/NumPy function that JAX doesn’t know about: a special function from SciPy, a custom C extension, or a third-party library with stable numerical properties.

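jax.pure_callback(fn, result_shape, *args) is the bridge (the second parameter is named result_shape_dtypes in the JAX API reference). It lets you call any pure Python function from inside a JAX computation, even inside jit, by handing control back to the Python interpreter (the “host”) for that one call. The callback receives its arguments as NumPy arrays and must return array(s) whose shapes and dtypes match result_shape. The key constraint: fn must be pure (no side effects, and the same output for the same input).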

Worked mini-example

import jax
import jax.numpy as jnp
import numpy as np
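from scipy.special import expit  # SciPy's logistic sigmoid (JAX has jax.nn.sigmoid; SciPy's is used here for the demo)

@jax.jit  # pure_callback is allowed under jit: the host call is staged into the compiled program
def scipy_sigmoid(x):
    # JAX cannot look inside expit while tracing, so declare the output shape/dtype up front.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # expit runs on the host, receives a NumPy array, and must return one matching result_shape.
    return jax.pure_callback(expit, result_shape, x)

x = jnp.array([0.0, 1.0, -1.0])
print(scipy_sigmoid(x))  # ≈ [0.5, 0.731, 0.269]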
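The mini-example passes a single array, but jax.pure_callback forwards any number of positional array arguments to the host function, and result_shape may be a pytree (here a tuple) with one jax.ShapeDtypeStruct per output. A minimal sketch, assuming a hypothetical host function host_diff_stats that returns the sum and the largest absolute value of a - b:

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_diff_stats(a, b):
    # Hypothetical host function: runs in Python on NumPy arrays, not on JAX tracers.
    diff = a - b
    # Return 0-d NumPy arrays whose dtype matches the declared result shapes.
    return (
        np.asarray(diff.sum(), dtype=a.dtype),
        np.asarray(np.abs(diff).max(), dtype=a.dtype),
    )


@jax.jit
def diff_stats(a, b):
    # One ShapeDtypeStruct per output; the tuple structure must match what the host returns.
    out_spec = (
        jax.ShapeDtypeStruct((), a.dtype),
        jax.ShapeDtypeStruct((), a.dtype),
    )
    # Both a and b are forwarded to host_diff_stats as positional arguments.
    return jax.pure_callback(host_diff_stats, out_spec, a, b)


total, worst = diff_stats(jnp.array([1.0, 4.0, 2.0]), jnp.array([0.5, 1.0, 3.0]))
print(total, worst)
```

The result comes back with the same tuple structure as out_spec, so it can be unpacked directly.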

Common pitfalls

  • result_shape is mandatory. JAX needs it at trace time to know the output shape and dtype before the host function ever runs, and the value the callback returns must match it. There is no default: leave it out and the call fails.
  • fn must be pure. It must not write to global state, print, or return different outputs for the same input, because JAX may call it more than once (for example, once per batch element under vmap) or skip it entirely if its result is unused.
  • Host callbacks are slow. Each call hands control back to Python and, on an accelerator, copies inputs to the host and results back, which breaks JAX’s asynchronous dispatch. Avoid them in hot loops.
  • Not differentiable by default. jax.grad treats jax.pure_callback as a black box and raises an error if you try to differentiate through it. To supply a gradient, wrap the callback in jax.custom_jvp or jax.custom_vjp.
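For the last pitfall, here is a minimal sketch of attaching a gradient with jax.custom_jvp, reusing SciPy’s expit as the host function and assuming the analytic derivative σ′(x) = σ(x)(1 − σ(x)). The primal value still comes from the callback; only the tangent rule is written in jax.numpy so that jax.grad can trace it. The names host_sigmoid and sigmoid_cb are illustrative.

```python
import jax
import jax.numpy as jnp
from scipy.special import expit


def host_sigmoid(x):
    # Runs on the host with a NumPy array; cast so the dtype matches the declared result.
    return expit(x).astype(x.dtype)


@jax.custom_jvp
def sigmoid_cb(x):
    return jax.pure_callback(host_sigmoid, jax.ShapeDtypeStruct(x.shape, x.dtype), x)


@sigmoid_cb.defjvp
def sigmoid_cb_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = sigmoid_cb(x)  # primal output, still computed through the host callback
    # Tangent rule: d sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)) * dx, in traceable jnp ops.
    return y, y * (1.0 - y) * x_dot


grads = jax.grad(lambda x: sigmoid_cb(x).sum())(jnp.array([0.0, 2.0]))
print(grads)
```

jax.custom_vjp would also work if you prefer to write the backward pass directly; custom_jvp is used here because the same rule also serves forward mode (jax.jvp).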

Problem

Implement pure_callback_doubler(x) that uses jax.pure_callback to invoke a NumPy doubler (arr * 2) from inside JAX.

  • x: 1-D jax array.

Returns: 1-D jax array — 2 * x (computed via host callback).
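One possible solution sketch, not a reference solution. It assumes a hypothetical host-side helper named _numpy_doubler and follows the same pattern as the worked mini-example: declare the output with jax.ShapeDtypeStruct, then hand the NumPy function to jax.pure_callback.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _numpy_doubler(arr):
    # Hypothetical host-side NumPy helper: arr arrives as a NumPy array, not a JAX tracer.
    # Cast back to the input dtype so the result matches the declared ShapeDtypeStruct.
    return np.asarray(arr * 2, dtype=arr.dtype)


def pure_callback_doubler(x):
    # Doubling changes neither shape nor dtype, so the output spec mirrors the input.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_numpy_doubler, result_shape, x)


x = jnp.array([1.0, 2.5, -3.0])
print(pure_callback_doubler(x))           # eager call
print(jax.jit(pure_callback_doubler)(x))  # same result when compiled
```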

Hints

jax pure-callback interop
