
Explicit device_put

Why this matters

JAX arrays live on devices: CPUs, GPUs, or TPUs. When you pass a NumPy array to a JAX operation, JAX transfers it to the default device implicitly, but sometimes you want to be explicit about where an array lives. jax.device_put(x) places x on the default device (or on a specific device or sharding when passed a second argument).
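The implicit and explicit paths can be compared directly. This is a minimal sketch assuming default JAX settings, where the default device is the first entry of jax.devices(). The names host, y, and x are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp

host = np.array([1.0, 2.0, 3.0], dtype=np.float32)  # host memory, not a jax.Array

# Implicit: a jnp operation transfers the NumPy input and returns a jax.Array.
y = jnp.sin(host)

# Explicit: device_put performs the same host-to-device transfer up front.
x = jax.device_put(host)

print(isinstance(y, jax.Array), isinstance(x, jax.Array))  # True True
print(x.devices())  # the set of devices holding x
```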

On a single-device setup, calling device_put on a JAX array mainly documents intent: if x is already on the default device, no copy happens (a NumPy x is still copied from host to device). Its real power shows up in multi-device code: jax.device_put(x, devices[1]) pins the array to a particular device, and jax.device_put(x, sharding) distributes it across a mesh.
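The sharded case can be sketched with the jax.sharding API (Mesh, NamedSharding, PartitionSpec) found in recent JAX releases. The sketch assumes a single-process setup and builds a 1-D mesh over all visible devices, so it also runs on a lone CPU, where the mesh has one device and the array ends up in a single shard.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())          # all visible devices as a 1-D array
mesh = Mesh(devices, ("data",))            # 1-D mesh with one axis named "data"
sharding = NamedSharding(mesh, P("data"))  # split dimension 0 across the "data" axis

n = len(devices)
x = jnp.arange(4.0 * n)                    # length 4n, divisible by the device count
x_sharded = jax.device_put(x, sharding)    # each device receives a contiguous slice

print(x_sharded.sharding)
for shard in x_sharded.addressable_shards:
    print(shard.device, shard.data.shape)  # one (4,)-shaped slice per device
```

On a machine with 8 devices, each shard would hold 4 of the 32 elements; on a single CPU the one shard holds all 4.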

Worked mini-example

import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])           # created on default device
x_explicit = jax.device_put(x)            # explicit placement (no-op here)

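# On a multi-device machine, pin x to the second device
# (index 1 in jax.devices(), not necessarily a GPU):
devices = jax.devices()
if len(devices) > 1:
    x_on_dev1 = jax.device_put(x, devices[1])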

Common pitfalls

  • NumPy → device causes a copy: if x is a NumPy array (not a JAX array), device_put copies it from host memory to the device. That is fine once, but calling it inside a tight loop means one copy per iteration, so hoist the call out of the loop.
  • Same-device is a no-op: if the array is already on the target device, no copy is made. Redundant calls are harmless outside hot paths.
  • Don’t confuse with jnp.asarray: both bring NumPy → JAX, but device_put is the idiomatic way when device placement is the primary concern.
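The first pitfall can be illustrated with a small sketch. The names host_data, slow_sum, and fast_sum are hypothetical, chosen only for this illustration. Both functions compute the same result, but the first pays for a host-to-device transfer on every iteration.

```python
import numpy as np
import jax
import jax.numpy as jnp

host_data = np.ones((1000,), dtype=np.float32)  # lives in host memory

def slow_sum(steps):
    # Anti-pattern: device_put on a NumPy array copies it to the device each iteration.
    total = jnp.zeros(())
    for _ in range(steps):
        total = total + jnp.sum(jax.device_put(host_data))
    return total

def fast_sum(steps):
    # Better: transfer once, then reuse the device-resident array.
    dev_data = jax.device_put(host_data)
    total = jnp.zeros(())
    for _ in range(steps):
        total = total + jnp.sum(dev_data)
    return total

print(slow_sum(3), fast_sum(3))
```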

Problem

Implement device_put_then_double(x) that explicitly places x on the default device with jax.device_put, then returns 2 * x.

  • x: 1-D array (NumPy or JAX).

Returns: 1-D JAX array of the same shape with all values doubled.

Example (not from the test set):

  • device_put_then_double(jnp.array([5.0, 6.0])) returns [10.0, 12.0].
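One possible sketch of device_put_then_double, assuming default JAX settings; it is not necessarily the reference solution. The function first places x with jax.device_put (a host-to-device copy for NumPy input, no copy for a JAX array already on the default device) and then multiplies by a Python scalar, which yields a jax.Array of the same shape.

```python
import numpy as np
import jax
import jax.numpy as jnp

def device_put_then_double(x):
    """Explicitly place x on the default device, then double every element."""
    x_dev = jax.device_put(x)  # NumPy input: copied to device; JAX input: already there
    return 2 * x_dev           # elementwise doubling; result is a jax.Array

print(device_put_then_double(jnp.array([5.0, 6.0])))
print(device_put_then_double(np.array([1.0, 2.0, 3.0])))
```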
