Host Roundtrip via device_get
Why this matters
Sometimes you must leave JAX: to call a NumPy-only library, to
serialise an array, or to do in-place mutations that JAX’s functional model
forbids. jax.device_get(x) copies the array from the accelerator to host
memory and returns a plain NumPy array. jnp.asarray(np_array) sends it
back to the device.
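As a minimal sketch of the serialisation case, the snippet below pulls a small parameter pytree to host, writes it with np.savez into an in-memory buffer, and pushes it back. The parameter names w and b are hypothetical, and the io.BytesIO buffer stands in for a real file. jax.device_get accepts a whole pytree, so a single call moves every leaf.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

# A small parameter pytree on the default device (hypothetical names).
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# device_get maps over the pytree and returns the same structure of NumPy arrays.
host_params = jax.device_get(params)
assert all(isinstance(v, np.ndarray) for v in host_params.values())

# Serialise on the host; an in-memory buffer stands in for a file on disk.
buf = io.BytesIO()
np.savez(buf, **host_params)
buf.seek(0)

# Load the arrays and push each one back to the device.
with np.load(buf) as data:
    restored = {k: jnp.asarray(data[k]) for k in data.files}

print(sorted(restored))  # ['b', 'w']
```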
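This round-trip is expensive: two host-device transfers over PCIe (or a similar interconnect), plus a synchronisation barrier. The barrier is often the larger cost: device_get blocks Python until every queued computation that produces x has finished, so JAX’s asynchronous dispatch can no longer overlap host work with device work. On a GPU the round-trip can easily add milliseconds per call. In a training loop targeting 1,000 steps per second (a 1 ms budget per step), a 1 ms round-trip per step would on its own consume the entire budget.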
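To see the cost on your own hardware, a rough timing harness like the one below may help. It is a sketch, not a benchmark: the array size and iteration count are arbitrary choices, block_until_ready() is needed because JAX dispatches work asynchronously, and on a CPU-only backend the "transfer" may be little more than a memory copy, so the gap there will likely be much smaller than on a GPU.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def double_on_device(x):
    # Stays on the accelerator; no host transfer.
    return x * 2


def double_via_host(x):
    # device -> host, NumPy work, host -> device.
    return jnp.asarray(jax.device_get(x) * 2)


def time_it(fn, x, n=100):
    fn(x).block_until_ready()  # warm-up, includes JIT compilation
    start = time.perf_counter()
    for _ in range(n):
        out = fn(x)
    out.block_until_ready()  # wait for queued async work before stopping the clock
    return (time.perf_counter() - start) / n


x = jnp.ones((1_000_000,), dtype=jnp.float32)
t_dev = time_it(double_on_device, x)
t_host = time_it(double_via_host, x)
print(f"on device: {t_dev * 1e6:.1f} us/call, via host: {t_host * 1e6:.1f} us/call")
```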
Understanding the pattern, and its cost, helps you recognise when code is
accidentally leaking to host (e.g., wrapping a result in np.array(...)
inside a Python training loop).
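Whether np.array(...) leaks silently depends on context. In eager code, such as a plain Python loop, it quietly copies the data to host; inside a function traced by jax.jit there is no concrete data to copy, so JAX raises jax.errors.TracerArrayConversionError instead. A minimal sketch of both cases (the function name loss is illustrative only):

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(3.0)

# Eager code: np.array(...) silently copies the data to host.
host_copy = np.array(x)
print(type(host_copy))  # <class 'numpy.ndarray'>


@jax.jit
def loss(x):
    # Under jit, x is an abstract tracer with no concrete values,
    # so converting it to a NumPy array fails.
    return np.array(x).sum()


try:
    loss(x)
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True
print(raised)  # True
```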
Worked mini-example
import jax
import jax.numpy as jnp
import numpy as np
x = jnp.array([1.0, 2.0, 3.0])
# Pull to host
np_x = jax.device_get(x) # np.ndarray on CPU
print(type(np_x)) # <class 'numpy.ndarray'>
# Do host work
np_doubled = np_x * 2 # pure NumPy
# Push back to device
jax_result = jnp.asarray(np_doubled)
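The same pattern covers the in-place mutation case. One hedged detail: the NumPy array returned by jax.device_get may be read-only (on the CPU backend it can share memory with the JAX buffer), so this sketch takes an explicit copy before writing to it. For comparison, x.at[0].set(...) is the functional JAX equivalent and needs no host round-trip.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.array([1.0, 2.0, 3.0])

# Copy explicitly: the array from device_get may not be writeable.
np_x = np.array(jax.device_get(x), copy=True)
np_x[0] = 100.0  # in-place NumPy mutation on the host copy
x_host_edit = jnp.asarray(np_x)

# Functional JAX equivalent, computed without leaving the device.
x_functional = x.at[0].set(100.0)

print(bool(jnp.array_equal(x_host_edit, x_functional)))  # True
```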
Common pitfalls
- Never do this in a hot path: device↔host transfers stall the GPU
pipeline. Profile first; if you see unexpected
CopyToHost ops in your XLA trace, something is triggering an implicit device_get.
- Implicit device_get via np.array(jax_array): wrapping a JAX array in np.array(...) performs an implicit device-to-host transfer. This is easy to do accidentally inside callbacks.
- Ask: can this be done in JAX? Before using
device_get, check if the operation can be expressed with jnp, jax.lax, or a custom JVP. Usually it can.
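To illustrate the last point, here is a sketch of a few host-side NumPy operations rewritten with jnp and jax.lax inside jax.jit, checked against a host reference. The particular operations (a clip, an indexed update, a cumulative sum, a descending sort) and the function names are arbitrary examples chosen for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def stays_on_device(x):
    # Every step has a jnp / lax counterpart, so nothing leaves the device.
    clipped = jnp.clip(x, 0.0, 1.0)      # instead of np.clip on host
    updated = clipped.at[0].set(0.5)     # instead of in-place arr[0] = 0.5
    running = jax.lax.cumsum(updated)    # instead of np.cumsum
    return jnp.sort(running)[::-1]       # instead of np.sort(...)[::-1]


def host_reference(x):
    # The same computation done on the host, for comparison only.
    arr = np.clip(np.array(jax.device_get(x)), 0.0, 1.0)
    arr[0] = 0.5
    return np.sort(np.cumsum(arr))[::-1]


x = jnp.array([1.5, -0.2, 0.3, 0.9])
print(np.allclose(np.asarray(stays_on_device(x)), host_reference(x)))  # True
```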
Problem
Implement host_roundtrip(x) that:
- Brings x to host with jax.device_get.
- Doubles the NumPy array on the host.
- Pushes the result back to device with jnp.asarray.
x: 1-D JAX array.
Returns: 1-D JAX array of the same shape with all values doubled.
Example (not from the test set):
- host_roundtrip(jnp.array([5.0, 10.0])) returns [10.0, 20.0].
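One possible sketch of host_roundtrip, following the three steps above and mirroring the worked mini-example. It assumes x is a 1-D jax.Array and does no input validation.

```python
import jax
import jax.numpy as jnp


def host_roundtrip(x):
    """Double a 1-D JAX array by way of a host round-trip."""
    np_x = jax.device_get(x)        # 1. device -> host (NumPy array)
    np_doubled = np_x * 2           # 2. pure NumPy work on the host
    return jnp.asarray(np_doubled)  # 3. host -> device


print(host_roundtrip(jnp.array([5.0, 10.0])))
```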