Multi-Host JAX (Conceptual)
Why this matters
Single-host JAX gives you control over all devices on one machine. For large-scale training, such as models too large for a single host's memory, you need multiple machines (hosts), each with multiple accelerators. JAX supports this through multi-host distributed initialization with jax.distributed.initialize.
The setup is simple but strict:
```python
import jax

# Call BEFORE any JAX computation, on every host:
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # host 0's address and port
    num_processes=4,                      # total number of hosts
    process_id=0,                         # this host's rank
)

# After init, jax.devices() returns ALL devices across ALL hosts:
print(jax.device_count())        # e.g. 32 (4 hosts × 8 GPUs each)
print(jax.local_device_count())  # 8 (devices on THIS host)
```
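Once initialized, Mesh, NamedSharding, and jax.jit with
in_shardings / out_shardings shard data and computation across the
full distributed device set without host-specific code in the model itself.
Every process must run the same program, and each process holds only the
array shards that live on its own local devices.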
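As a sketch of that workflow, the block below builds a one-axis mesh over jax.devices(), shards a batch along its leading axis, and lets a jitted function reduce it to a replicated scalar. It runs on a single host; the mesh axis name "data" and the function mean_of_squares are illustrative choices, not part of the lesson. In a real multi-host job each process would usually assemble the global array from its own local data (for example with jax.make_array_from_process_local_data, assuming your JAX version provides it) rather than calling jax.device_put on the full array.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over every device JAX can see. After jax.distributed.initialize
# this spans all hosts; on a single host it is just the local devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

batch_sharding = NamedSharding(mesh, P("data"))  # split axis 0 across "data"
replicated = NamedSharding(mesh, P())             # full copy on every device

def mean_of_squares(x):
    return jnp.mean(x ** 2)

# jit compiles one SPMD program; XLA inserts the cross-device reduction.
sharded_mean = jax.jit(
    mean_of_squares,
    in_shardings=batch_sharding,
    out_shardings=replicated,
)

n = jax.device_count()
x = jnp.arange(4.0 * n).reshape(4 * n, 1)  # batch size divisible by n
x = jax.device_put(x, batch_sharding)      # fine on one host (see lead-in)
result = sharded_mean(x)
print(result)  # 3.5 when n == 1
```

The model code (mean_of_squares) never mentions devices or hosts; only the shardings passed to jax.jit decide where data and computation live.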
Worked mini-example
```python
import jax

# Single-host (no init needed):
print(jax.device_count())        # 1 on the test runner CPU
print(jax.local_device_count())  # 1

# Multi-host (requires actual multi-host setup):
# jax.distributed.initialize("host0:1234", num_processes=4, process_id=my_rank)
# print(jax.device_count())  # 4 × local_device_count
```
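The same script can run unchanged on one host or many if it reads the topology at runtime instead of hard-coding it. A minimal sketch using the real JAX functions jax.process_index() and jax.process_count(); describe_topology is a hypothetical helper name, and restricting printing to process 0 is a common convention rather than a requirement.

```python
import jax

def describe_topology():
    """Summarize this process's view of the (possibly multi-host) device set."""
    return {
        "process_index": jax.process_index(),       # this host's rank; 0 without init
        "process_count": jax.process_count(),       # number of hosts; 1 without init
        "local_devices": jax.local_device_count(),  # devices attached to this host
        "global_devices": jax.device_count(),       # devices across all hosts
    }

info = describe_topology()

# Only one process logs, so a 4-host job does not print everything 4 times.
if info["process_index"] == 0:
    print(info)
```

On the single-CPU test runner this reports one process with matching local and global device counts.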
Common pitfalls
- Must call initialize BEFORE any JAX use: JAX captures the backend at first use; calling initialize after that raises an error.
- All processes must call initialize at roughly the same time: it is a collective, so each process waits until every process has connected to the coordinator. If one process hangs or never calls it, all of them hang.
- coordinator_address must be reachable: the coordinator host and port must be accessible from all workers.
- process_id must be unique: each process gets a distinct integer in [0, num_processes).
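Several of these pitfalls can be caught before JAX is touched at all. The sketch below assumes a launcher that exports three environment variables with the hypothetical names COORD_ADDR, NUM_PROCS and PROC_ID (substitute whatever your cluster actually sets), and checks the address format and the rank range. It cannot verify that ranks are unique across hosts or that the port is reachable; those remain the launcher's and network's job.

```python
import os

def distributed_args_from_env(env=None):
    """Build keyword arguments for jax.distributed.initialize from
    (hypothetical) COORD_ADDR / NUM_PROCS / PROC_ID variables."""
    env = os.environ if env is None else env
    address = env["COORD_ADDR"]
    num_processes = int(env["NUM_PROCS"])
    process_id = int(env["PROC_ID"])

    # Expect "host:port" with a numeric port.
    host, sep, port = address.rpartition(":")
    if not sep or not host or not port.isdigit():
        raise ValueError(f"coordinator address must be host:port, got {address!r}")
    if num_processes < 1:
        raise ValueError("num_processes must be at least 1")
    # Each rank must lie in [0, num_processes).
    if not 0 <= process_id < num_processes:
        raise ValueError(f"process_id {process_id} not in [0, {num_processes})")
    return {
        "coordinator_address": address,
        "num_processes": num_processes,
        "process_id": process_id,
    }

# Usage on each host, before any other JAX call:
#   import jax
#   jax.distributed.initialize(**distributed_args_from_env())
```

Failing fast with a ValueError on one host is easier to diagnose than every host blocking inside the collective.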
Problem
Implement n_devices_in_mesh(dummy) that returns jax.device_count() as
a jnp.float32 scalar. The dummy argument is ignored — it exists to
satisfy the test framework’s single-argument contract.
On the single-CPU test runner, jax.device_count() is always 1.
Returns: jnp.float32(jax.device_count()).
Examples (not from the test set):
- n_devices_in_mesh(0.0) → 1.0 (on the test runner)
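A minimal sketch that follows the stated return value directly. On a multi-host job initialized as above, the same function would report the global device count rather than the local one, since jax.device_count() spans all hosts.

```python
import jax
import jax.numpy as jnp

def n_devices_in_mesh(dummy):
    # `dummy` is unused; it only satisfies the single-argument contract.
    del dummy
    # Global device count, cast to a float32 scalar array.
    return jnp.float32(jax.device_count())
```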