
JAX <-> NumPy Bridge

Why this matters

JAX arrays live on an accelerator (GPU/TPU) or in a JAX-managed CPU buffer; NumPy arrays live in ordinary host RAM. The two worlds are separate. Inside a jit-compiled function you can't call np.sqrt on a JAX array, because the values there are abstract tracers with no concrete data for NumPy to read. In the other direction, JAX functions do accept NumPy arrays, but only by converting them to device arrays first, which is a host-to-device copy.
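A small sketch of the jit point, where the function names jnp_sqrt and np_sqrt are illustrative only. The version that calls np.sqrt fails while being traced (JAX typically raises jax.errors.TracerArrayConversionError), the jnp.sqrt version compiles, and outside jit a JAX function accepts a plain NumPy array.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def jnp_sqrt(x):
    # jnp functions operate on tracers, so this traces and compiles fine.
    return jnp.sqrt(x)

@jax.jit
def np_sqrt(x):
    # np.sqrt tries to read concrete values out of the tracer and fails.
    return np.sqrt(x)

x = jnp.array([1.0, 4.0, 9.0])
roots = jnp_sqrt(x)

try:
    np_sqrt(x)
    np_error = None
except Exception as err:  # typically jax.errors.TracerArrayConversionError
    np_error = type(err).__name__

# Outside jit, the NumPy input is converted to a device array implicitly.
implicit = jnp.sqrt(np.array([16.0]))
```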

Understanding the explicit bridge (jax.device_get and jnp.asarray) is essential for:

  • scipy interop: scipy functions expect NumPy arrays.
  • Custom serialization: saving/loading arrays via np.save/np.load.
  • Mixed pipelines: combining JAX model outputs with NumPy-based post-processing.
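A minimal sketch of the serialization case, assuming an in-memory io.BytesIO buffer stands in for a file on disk; weights is a made-up example array. The array is pulled to the host before np.save and pushed back with jnp.asarray after np.load.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

weights = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)

# Save: move the array to the host, then hand a plain ndarray to np.save.
buffer = io.BytesIO()  # in-memory stand-in for a file
np.save(buffer, jax.device_get(weights))

# Load: np.load returns an ndarray; jnp.asarray pushes it back to the device.
buffer.seek(0)
restored = jnp.asarray(np.load(buffer))
```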

Worked mini-example

import jax
import jax.numpy as jnp
import numpy as np
from scipy.special import expit  # SciPy ufunc; operates on NumPy arrays

jax_logits = jnp.array([0.0, 1.0, -1.0])

# JAX → NumPy (brings to host)
np_logits = jax.device_get(jax_logits)   # numpy.ndarray

# NumPy computation (scipy)
np_probs = expit(np_logits)              # sigmoid

# NumPy → JAX (pushes to device)
jax_probs = jnp.asarray(np_probs)       # jax.Array
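jax.device_get also accepts a pytree such as a dict of arrays, so several model outputs can be moved to the host in one call with their structure preserved. A sketch extending the example above, where the output names logits and labels are hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.special import expit

# Hypothetical model outputs bundled in a dict (a pytree).
outputs = {
    "logits": jnp.array([0.0, 1.0, -1.0]),
    "labels": jnp.array([1, 1, 0]),
}

# One call transfers every leaf; the dict keys are preserved.
host_outputs = jax.device_get(outputs)

# SciPy/NumPy post-processing on the host.
probs = expit(host_outputs["logits"])
accuracy = np.mean((probs > 0.5) == host_outputs["labels"])

# Push only what the JAX side needs back to the device.
jax_probs = jnp.asarray(probs)
```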

Common pitfalls

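  • Roundtrip is slow: jax.device_get waits for any pending computation to finish, then copies the data from device to host (for a GPU, typically over PCIe), and jnp.asarray copies it back. Avoid the roundtrip in hot loops such as a training step. Inside a jit-compiled function it is not possible at all, since values there are abstract tracers rather than concrete data.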
  • jnp.asarray vs np.asarray: jnp.asarray(np_arr) copies a NumPy array to the device, while np.asarray(jax_arr) goes the other way and, for a single array, has the same effect as jax.device_get(jax_arr).
  • Implicit conversion: JAX silently converts NumPy arrays to JAX arrays whenever they are passed to a JAX function, so jnp.sum(np_arr) works. Explicit conversion still makes the intent (and the transfer) visible and avoids surprises.
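  • dtype changes: NumPy defaults to float64, while JAX defaults to float32 unless 64-bit mode (jax_enable_x64) is enabled. A float64 NumPy array passed through jnp.asarray therefore becomes float32, so check dtypes after the roundtrip.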
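The pitfalls above can be checked in a few lines. This sketch assumes the default configuration (64-bit mode off); the per-step values in the loop are hypothetical stand-ins for something like training losses, and show one way to keep transfers out of the loop body.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax_arr = jnp.array([1.0, 2.0, 3.0])

# np.asarray and jax.device_get both yield a host ndarray with the same values.
via_np = np.asarray(jax_arr)
via_get = jax.device_get(jax_arr)
same_values = np.array_equal(via_np, via_get)

# Implicit conversion: jnp functions accept NumPy input and return a jax.Array.
np_arr = np.array([1.0, 2.0, 3.0])  # float64 on the host
total = jnp.sum(np_arr)

# dtype: with 64-bit mode off, float64 input becomes float32 on the device.
converted = jnp.asarray(np_arr)
# Casting explicitly on the host documents the intended dtype.
explicit = jnp.asarray(np_arr.astype(np.float32))

# Keep per-step results on the device inside the loop,
# then move them to the host with a single transfer at the end.
step_values = []
for step in range(3):
    step_values.append(jnp.sum(jax_arr * step))  # stays a jax.Array
host_values = jax.device_get(jnp.stack(step_values))
```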

Problem

Implement jax_to_numpy_to_jax(x) that explicitly:

  1. Moves x to host as a NumPy array with jax.device_get(x).
  2. Doubles it with NumPy arithmetic (np_x * 2).
  3. Pushes the result back to JAX with jnp.asarray(...).
Input: x, a 1-D JAX array.

Returns: a 1-D JAX array equal to 2 * x.
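One way to write the three steps, as a sketch rather than the reference solution; it follows the numbered steps literally and is meant to be called outside jit.

```python
import jax
import jax.numpy as jnp
import numpy as np

def jax_to_numpy_to_jax(x):
    """Double a 1-D JAX array via an explicit NumPy roundtrip."""
    np_x = jax.device_get(x)        # 1. device -> host, returns numpy.ndarray
    np_doubled = np_x * 2           # 2. plain NumPy arithmetic on the host
    return jnp.asarray(np_doubled)  # 3. host -> device, returns jax.Array

x = jnp.array([1.0, -2.0, 3.5])
y = jax_to_numpy_to_jax(x)
```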

Hints

jax numpy interop
