easy
primitives
JAX <-> NumPy Bridge
Why this matters
JAX arrays live on an accelerator (GPU/TPU) or in a JAX-managed CPU
buffer. NumPy arrays live in normal host RAM. The two worlds are
separate: you can't call np.sqrt on a JAX array inside
a jit-compiled function, because the value there is an abstract tracer
rather than concrete data, and passing a NumPy array to a JAX function
converts it implicitly, which hides a host-to-device transfer.
Understanding the explicit bridge, jax.device_get and
jnp.asarray, is essential for:
- scipy interop: scipy functions expect NumPy arrays.
- Custom serialization: saving/loading arrays via np.save/np.load.
- Mixed pipelines: combining JAX model outputs with NumPy-based post-processing.
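As an illustration of the serialization case, here is a minimal sketch that saves a JAX array with np.save and restores it with np.load. The array name `weights` and the in-memory `io.BytesIO` buffer (standing in for a file on disk) are assumptions made for the example, not part of the original problem.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

weights = jnp.arange(4, dtype=jnp.float32)  # hypothetical array to persist

# np.save works on host data, so bring the array to host explicitly first.
buffer = io.BytesIO()  # in-memory stand-in for a file on disk
np.save(buffer, jax.device_get(weights))

# np.load returns a numpy.ndarray; jnp.asarray hands it back to JAX.
buffer.seek(0)
restored = jnp.asarray(np.load(buffer))
```

The explicit device_get call makes the device-to-host transfer visible at the point where it happens, rather than leaving it to an implicit conversion inside np.save.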
Worked mini-example
import jax
import jax.numpy as jnp
import numpy as np
from scipy.special import expit # NumPy-only
jax_logits = jnp.array([0.0, 1.0, -1.0])
# JAX -> NumPy (brings to host)
np_logits = jax.device_get(jax_logits) # numpy.ndarray
# NumPy computation (scipy)
np_probs = expit(np_logits) # sigmoid
# NumPy -> JAX (pushes to device)
jax_probs = jnp.asarray(np_probs) # jax.Array
Common pitfalls
- Roundtrip is slow: device_get triggers a transfer from device to host (over PCIe/NVLink on accelerators), and jnp.asarray transfers back. Avoid this inside a training loop, and do not rely on it inside a jit-compiled function, where values are abstract tracers rather than concrete arrays.
- jnp.asarray vs np.asarray: jnp.asarray(np_arr) copies to device; np.asarray(jax_arr) brings the data to host, much like device_get.
- Implicit conversion: JAX will silently convert NumPy arrays to JAX arrays on first use, so jnp.sum(np_arr) works. But explicit conversion makes the intent clear and avoids surprises.
- dtype changes: NumPy defaults to float64; JAX defaults to float32 (unless jax_enable_x64 is set). After the roundtrip, check that dtypes are as expected.
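The dtype and implicit-conversion pitfalls can be checked directly. The following is a small sketch; the array values are arbitrary, and the printed JAX dtype depends on whether jax_enable_x64 has been set in your environment.

```python
import jax
import jax.numpy as jnp
import numpy as np

np_arr = np.array([0.5, 1.5, 2.5])  # NumPy default dtype: float64
jax_arr = jnp.asarray(np_arr)       # float32 unless jax_enable_x64 is set

print(np_arr.dtype, jax_arr.dtype)  # inspect both ends of the bridge

# Implicit conversion: jnp functions accept NumPy input directly.
total = jnp.sum(np_arr)             # the result is a jax.Array

# Pinning the dtype on the NumPy side keeps both ends consistent.
np_f32 = np_arr.astype(np.float32)
roundtrip = jax.device_get(jnp.asarray(np_f32))
```

Casting to float32 before the transfer is one way to avoid a silent precision change; enabling 64-bit mode in JAX is the alternative if float64 is actually required.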
Problem
Implement jax_to_numpy_to_jax(x) that explicitly:
- Moves x to host as a NumPy array with jax.device_get(x).
- Doubles it with NumPy arithmetic (np_x * 2).
- Pushes the result back to JAX with jnp.asarray(...).

- x: 1-D JAX array.
Returns: 1-D JAX array equal to 2 * x.
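One possible shape for the three steps above is sketched below. It is an illustrative attempt rather than the site's reference solution, and the intermediate names `np_x` and `np_doubled` are chosen only for readability.

```python
import jax
import jax.numpy as jnp
import numpy as np


def jax_to_numpy_to_jax(x):
    """Double a 1-D JAX array via an explicit host roundtrip."""
    np_x = jax.device_get(x)        # device -> host, numpy.ndarray
    np_doubled = np_x * 2           # plain NumPy arithmetic on the host
    return jnp.asarray(np_doubled)  # host -> device, jax.Array


out = jax_to_numpy_to_jax(jnp.array([1.0, 2.0, 3.0]))
print(type(out), out)
```

Because the function pulls concrete values to the host, it is meant to be called eagerly, not from inside a jit-compiled function.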
Hints
jax
numpy
interop