medium primitives

Host Roundtrip via device_get

Why this matters

Sometimes you must leave JAX: to call a NumPy-only library, to serialise an array, or to do in-place mutations that JAX’s functional model forbids. jax.device_get(x) copies the array from the accelerator to host memory and returns a plain NumPy array; jnp.asarray(np_array) copies it back to the default device.

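This round-trip is expensive: two transfers over the host-device interconnect (typically PCIe), plus a synchronisation point, because jax.device_get has to wait until the computation producing x has finished. That wait also stalls JAX's asynchronous dispatch, which normally lets Python queue the next step while the accelerator is still busy. Depending on array size and hardware, a round-trip costs anywhere from tens of microseconds to several milliseconds. In a training loop that would otherwise run 1,000 steps per second (1 ms per step), one extra 1 ms round-trip per step roughly halves throughput.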
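To see the cost on your own hardware, here is a minimal timing sketch. The helper names double_on_device, double_via_host and time_per_call are illustrative, and the 1,000,000-element array is an arbitrary choice. Numbers vary by backend; on a CPU-only install there is no separate accelerator memory, so the gap may be small.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def double_on_device(x):
    # Stays on the device; no host transfer.
    return x * 2


def double_via_host(x):
    # device_get blocks until x is ready, then copies it into host memory.
    np_x = jax.device_get(x)
    # jnp.asarray copies the NumPy result back to the default device.
    return jnp.asarray(np_x * 2)


def time_per_call(fn, x, iters=100):
    fn(x).block_until_ready()  # warm-up, includes jit compilation
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(x)
    # JAX dispatches asynchronously, so wait for queued work before stopping the clock.
    out.block_until_ready()
    return (time.perf_counter() - start) / iters


x = jnp.ones((1_000_000,), dtype=jnp.float32)
print(f"on device: {time_per_call(double_on_device, x) * 1e6:.1f} us/call")
print(f"via host:  {time_per_call(double_via_host, x) * 1e6:.1f} us/call")
```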

Understanding the pattern and its cost helps you recognise when code is leaking data to host by accident, e.g. calling np.array(...) or float(...) on a JAX array every training step just to log a value.
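A small sketch of how such a leak behaves, using a hypothetical loss function loss_leaky (not part of any library): when called eagerly, np.array(...) quietly copies data to host on every call, while under a transformation such as jax.grad the same call on a traced value fails with jax.errors.TracerArrayConversionError. The jnp version avoids both problems.

```python
import jax
import jax.numpy as jnp
import numpy as np


def loss_leaky(w, x):
    # np.array(...) demands a concrete host copy of w * x.
    pred = np.array(w * x)
    return np.sum(pred ** 2)


def loss_pure(w, x):
    # Same computation in jnp: stays on device and stays traceable.
    return jnp.sum((w * x) ** 2)


x = jnp.array([1.0, 2.0])

# Eagerly, the leak is silent: this works but transfers w * x to host.
print(loss_leaky(2.0, x))

# Under jax.grad, w is a tracer and the host conversion raises.
try:
    jax.grad(loss_leaky)(2.0, x)
except jax.errors.TracerArrayConversionError as err:
    print("grad failed:", type(err).__name__)

# Gradient with respect to w, computed entirely in JAX.
print(jax.grad(loss_pure)(2.0, x))
```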

Worked mini-example

import jax
import jax.numpy as jnp
import numpy as np

x = jnp.array([1.0, 2.0, 3.0])

# Pull to host
np_x = jax.device_get(x)          # np.ndarray on CPU
print(type(np_x))                  # <class 'numpy.ndarray'>

# Do host work
np_doubled = np_x * 2             # pure NumPy

# Push back to device
jax_result = jnp.asarray(np_doubled)
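In-place mutation is one of the reasons given above for going to host. A minimal sketch, assuming you want to zero every element above a threshold (1.5 here is arbitrary): take an explicit copy first, because the array returned by jax.device_get may be a read-only view on some backends.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# .copy() guarantees a writable host buffer, whatever device_get returned.
np_x = jax.device_get(x).copy()

# In-place NumPy mutation: zero every element greater than 1.5.
np_x[np_x > 1.5] = 0.0

# Push the mutated array back to the default device.
jax_result = jnp.asarray(np_x)
print(jax_result)
```

The original x is untouched; only the host copy was mutated.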

Common pitfalls

  • Never do this in a hot path: device↔host transfers stall the accelerator pipeline. Profile first; if a trace shows unexpected device-to-host copies (on GPU, memcpy DtoH events), something is pulling data to host implicitly.
  • Implicit transfers via np.array(jax_array): wrapping a JAX array in np.array(...) silently copies it to host, and feeding the result back into JAX completes the round-trip. This is easy to do accidentally inside callbacks.
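  • Ask: can this be done in JAX? Before using device_get, check whether the operation can be expressed with jnp or jax.lax (plus jax.custom_jvp if it also needs a hand-written derivative). Usually it can: in-place updates, for example, become x.at[i].set(v).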
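A sketch of the JAX-native alternatives. The mini-example's doubling needs no transfer at all, and index updates have a functional form via .at. If a NumPy-only function is genuinely unavoidable inside jax.jit, jax.pure_callback runs it on the host without breaking tracing, though it still pays the transfer cost. host_double and double_with_callback are hypothetical names for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.array([1.0, 2.0, 3.0])

# The mini-example's host work, written directly in jnp.
doubled = x * 2

# Functional "in-place" update: returns a new array, x is unchanged.
zeroed_first = x.at[0].set(0.0)


def host_double(a):
    # Runs on the host and receives a NumPy array.
    return np.asarray(a) * 2


@jax.jit
def double_with_callback(a):
    # pure_callback needs the output shape and dtype declared up front.
    out_spec = jax.ShapeDtypeStruct(a.shape, a.dtype)
    return jax.pure_callback(host_double, out_spec, a)


print(doubled, zeroed_first, double_with_callback(x))
```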

Problem

Implement host_roundtrip(x) that:

  1. Brings x to host with jax.device_get.
  2. Doubles the NumPy array on the host.
  3. Pushes the result back to device with jnp.asarray.

Arguments:

  • x: 1-D JAX array.

Returns: 1-D JAX array of the same shape with all values doubled.

Example (not from the test set):

  • host_roundtrip(jnp.array([5.0, 10.0])) returns [10.0, 20.0].
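One possible implementation, following the three steps literally. This is a sketch; the hidden test set may check details (such as dtype handling) that the statement does not spell out.

```python
import jax
import jax.numpy as jnp


def host_roundtrip(x):
    np_x = jax.device_get(x)        # 1. copy to host as a NumPy array
    np_doubled = np_x * 2           # 2. double on the host with NumPy
    return jnp.asarray(np_doubled)  # 3. copy back to the default device


result = host_roundtrip(jnp.array([5.0, 10.0]))
print(result)
```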

