Explicit device_put
Why this matters
JAX arrays live on devices: CPUs, GPUs, or TPUs. Normally JAX moves data
to the default device implicitly, for example when a NumPy array is first
passed to a JAX operation. Sometimes, though, you want to be explicit about
where an array lives. jax.device_put(x) places x on the default device, or
on a specific device or sharding when you pass one as the second argument.
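Here is a minimal sketch of the host-to-device case described above. It assumes a recent JAX release, where device_put returns a jax.Array that has a .devices() method. The variable names are only illustrative.

```python
import numpy as np
import jax

# A NumPy array lives in host memory.
x_np = np.array([1.0, 2.0, 3.0], dtype=np.float32)

# device_put copies it to the default device and returns a jax.Array.
x_dev = jax.device_put(x_np)

# By default, the default device is the first device of the default backend.
default_device = jax.devices()[0]
print(x_dev.devices())                      # set of devices holding x_dev
print(x_dev.devices() == {default_device})
```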
On a single-device setup, calling it on a JAX array that already lives on
the default device is a clarity-only call: no copy happens. A NumPy input,
by contrast, is still copied from host to device. The real power of
device_put shows up in multi-device code. jax.device_put(x, devices[1])
pins the array to a particular device, and jax.device_put(x, sharding)
distributes it across a mesh.
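The sketch below covers the sharding case. It assumes the jax.sharding API (Mesh, NamedSharding, PartitionSpec) found in recent JAX versions. The mesh axis name "data" is an arbitrary choice. The array length is a multiple of the device count so that it splits evenly. The code also runs on a single-device machine, where the mesh holds one device and the array has a single shard.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over every available device; "data" is an arbitrary axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Split the array's only dimension across the "data" mesh axis.
sharding = NamedSharding(mesh, PartitionSpec("data"))

# Use a length divisible by the device count so each device gets 4 elements.
n = 4 * len(devices)
x = jnp.arange(n, dtype=jnp.float32)
x_sharded = jax.device_put(x, sharding)

for shard in x_sharded.addressable_shards:
    print(shard.device, shard.data.shape)
```

On a host with several accelerators, each shard lands on a different device. The printed loop shows one 4-element slice per device.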
Worked mini-example
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])  # created on the default device
x_explicit = jax.device_put(x)  # explicit placement (no copy: x is already there)

# On a multi-device host, pin the array to the second device:
devices = jax.devices()
if len(devices) > 1:
    x_on_dev1 = jax.device_put(x, devices[1])
Common pitfalls
- NumPy → device causes a copy: if x is a NumPy array (not a JAX array), device_put copies it to the device. That is fine once. Inside a tight loop, however, it means one host-to-device copy per iteration, so move the call out of the loop.
- Same-device is a no-op: if the array is already on the target device, JAX skips the copy. Don't worry about redundant calls outside hot paths.
- Don't confuse with jnp.asarray: both bring NumPy → JAX, but device_put is the idiomatic choice when device placement is the primary concern.
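The loop pitfall can be sketched as follows. The step function is hypothetical and stands in for any jitted update. The only point is where the host-to-device transfer happens. The final lines show that jnp.asarray produces the same values as device_put.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def step(w, x):
    # Hypothetical update rule: something cheap to run once per iteration.
    return w + 0.1 * x

x_host = np.ones(1024, dtype=np.float32)

# Anti-pattern: device_put inside the loop copies host -> device every iteration.
w_slow = jnp.zeros(1024, dtype=jnp.float32)
for _ in range(100):
    w_slow = step(w_slow, jax.device_put(x_host))

# Better: transfer once, then reuse the device-resident array.
x_dev = jax.device_put(x_host)
w_fast = jnp.zeros(1024, dtype=jnp.float32)
for _ in range(100):
    w_fast = step(w_fast, x_dev)

# jnp.asarray also converts NumPy -> jax.Array, with the same values.
x_asarray = jnp.asarray(x_host)
```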
Problem
Implement device_put_then_double(x) that explicitly places x on the
default device with jax.device_put, then returns 2 * x.
- x: 1-D array (NumPy or JAX).
- Returns: 1-D JAX array of the same shape with all values doubled.
Example (not from the test set):
- device_put_then_double(jnp.array([5.0, 6.0])) returns [10.0, 12.0].
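Below is a minimal sketch that satisfies the spec above. It is one possible approach and not the exercise's official solution. It accepts either NumPy or JAX input, because device_put handles both.

```python
import numpy as np
import jax
import jax.numpy as jnp

def device_put_then_double(x):
    # Explicitly place x on the default device (copies if x is a NumPy array).
    x_dev = jax.device_put(x)
    # Arithmetic on a jax.Array returns a jax.Array of the same shape.
    return 2 * x_dev

print(device_put_then_double(jnp.array([5.0, 6.0])))
```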