hard primitives

Multi-Host JAX (Conceptual)

Why this matters

Single-host JAX gives you control over all devices on one machine. For large-scale training — models too large for a single host’s memory — you need multiple machines (hosts), each with multiple accelerators. JAX supports this via multi-host distributed initialization.

The setup is simple but strict:

import jax

# Call BEFORE any JAX computation:
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # address and port of process 0 (the coordinator)
    num_processes=4,                       # total number of processes (here, one per host)
    process_id=0,                          # this process's rank, in [0, num_processes)
)

# After init, jax.devices() returns ALL devices across ALL hosts:
print(jax.device_count())   # e.g. 32 (4 hosts × 8 GPUs each)
print(jax.local_device_count())  # 8 (devices on THIS host)

Once initialized, Mesh, NamedSharding, and jax.jit with in_shardings / out_shardings can shard data and computation across the full distributed device set. Every process still runs the same program, and each process can directly address only its own local devices.
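As a minimal sketch of how these pieces fit together, the snippet below builds a one-dimensional Mesh over every visible device and uses jax.jit with in_shardings / out_shardings to split a batch along its leading axis. The axis name "data", the helper row_sums, and the batch shape are illustrative assumptions, not part of the original text. It is written to run in a single process; in a real multi-host job each process would more typically assemble the global array from its own local shard (for example with jax.make_array_from_process_local_data) instead of materializing the whole batch.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over every device JAX can see: all hosts' devices after
# jax.distributed.initialize, or just the local ones in a single process.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Split axis 0 of an array across the "data" mesh axis.
sharding = NamedSharding(mesh, P("data"))

def row_sums(x):
    # Each device only needs the rows it holds to compute its part.
    return x.sum(axis=1)

# Ask jit to place the input and the output with the sharding above.
sharded_row_sums = jax.jit(
    row_sums, in_shardings=(sharding,), out_shardings=sharding
)

# The row count is a multiple of the device count so it divides evenly.
n_rows = jax.device_count() * 2
batch = jnp.arange(n_rows * 4, dtype=jnp.float32).reshape(n_rows, 4)

result = sharded_row_sums(batch)
print(result.shape, result.sharding)
```

On a single-device machine the mesh has one entry and the sharding is trivial, so the same code runs unchanged; only the device list grows in a multi-host job.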

Worked mini-example

import jax

# Single-host (no init needed):
print(jax.device_count())        # 1 on the test runner CPU
print(jax.local_device_count())  # 1

# Multi-host (requires actual multi-host setup):
# jax.distributed.initialize("host0:1234", num_processes=4, process_id=my_rank)
# print(jax.device_count())  # 4 × local_device_count

Common pitfalls

  • Must call initialize BEFORE any JAX use: JAX captures the backend at first use; calling initialize after that raises an error.
  • All processes must call initialize: the call blocks until every process has connected to the coordinator, so if one process never reaches it, the others hang until initialization times out.
  • coordinator_address must be reachable: the coordinator host and port must be accessible from all workers.
  • process_id must be unique: each process gets a distinct integer in the range [0, num_processes), and process 0 is the one that hosts the coordinator.
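One way to guard against several of these pitfalls is to wrap initialization in a small helper that runs at the very top of the program, before any JAX computation. The sketch below is only an illustration: the helper name maybe_initialize and the variable names NUM_PROCESSES, PROCESS_ID, and COORDINATOR_ADDRESS are hypothetical and would need to match whatever your launcher actually provides.

```python
import os
import jax

def maybe_initialize(env=os.environ):
    """Initialize multi-process JAX only if the launcher configured it.

    NUM_PROCESSES, PROCESS_ID and COORDINATOR_ADDRESS are hypothetical
    variable names chosen for this sketch, not names JAX itself reads.
    Returns True if jax.distributed.initialize was called.
    """
    num_processes = int(env.get("NUM_PROCESSES", "1"))
    if num_processes <= 1:
        return False  # single-process run: no initialization needed

    # Validate the rank locally before contacting the coordinator.
    process_id = int(env["PROCESS_ID"])
    if not 0 <= process_id < num_processes:
        raise ValueError(
            f"process_id {process_id} is outside [0, {num_processes})"
        )

    # Blocks until all processes have connected to the coordinator.
    jax.distributed.initialize(
        coordinator_address=env["COORDINATOR_ADDRESS"],
        num_processes=num_processes,
        process_id=process_id,
    )
    return True

# With an empty configuration the helper is a no-op.
print(maybe_initialize({}))
```

Checking the rank before calling initialize turns a misconfigured process into an immediate local error instead of a stalled job.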

Problem

Implement n_devices_in_mesh(dummy) that returns jax.device_count() as a jnp.float32 scalar. The dummy argument is ignored — it exists to satisfy the test framework’s single-argument contract.

On the single-CPU test runner, jax.device_count() is always 1.

Returns: jnp.float32(jax.device_count()).

Examples (not from the test set):

  • n_devices_in_mesh(0.0) → 1.0 (on the test runner)
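A minimal sketch of a function matching this description follows. It is one possible reading of the problem statement, not the site's reference solution.

```python
import jax
import jax.numpy as jnp

def n_devices_in_mesh(dummy):
    # The argument is unused; it only satisfies the single-argument contract.
    del dummy
    # Global device count (all processes), returned as a float32 scalar.
    return jnp.float32(jax.device_count())

print(n_devices_in_mesh(0.0))
```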

Hints

jax multi-host distributed
