medium primitives

Mesh Creation

Why this matters

In JAX's distributed-computing model, a Mesh is the device topology: a multi-dimensional grid of physical devices (CPUs, GPUs, or TPUs) in which each grid axis has a name. Every sharding strategy you write later, with PartitionSpec and NamedSharding, refers to these axis names to say which dimension of an array maps to which group of devices.

import jax, numpy as np

devices = np.array(jax.devices()[:4]).reshape(2, 2)   # needs at least 4 devices, else reshape fails
mesh = jax.sharding.Mesh(devices, ('batch', 'model'))
# A 2×2 mesh: axis 'batch' has 2 slices, axis 'model' has 2 slices.

The Mesh object owns the mapping from axis names to device slices. len(mesh.axis_names) is the number of named axes (i.e. the number of dimensions in the device grid).
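A quick way to check the relationship between axis names and grid dimensions is to build a mesh and inspect it. The sketch below arranges every available device into an n × 1 grid so that it runs on any machine, including the single-CPU runner. The axis names are reused from the example above; reading per-axis sizes from mesh.shape is the only API used beyond what the text mentions.

```python
import jax
import numpy as np
from jax.sharding import Mesh

n = len(jax.devices())
# Arrange every available device into an n x 1 grid so this runs anywhere,
# including a machine with a single CPU device.
devices = np.array(jax.devices()).reshape(n, 1)
mesh = Mesh(devices, ('batch', 'model'))

num_axes = len(mesh.axis_names)   # one name per grid dimension
axis_sizes = dict(mesh.shape)     # axis name -> number of devices along that axis

print(num_axes)                   # 2
print(mesh.devices.ndim)          # 2, same as the number of axis names
```

Here axis_sizes maps 'batch' to n and 'model' to 1, so the number of named axes always equals the number of dimensions of the device array, whatever the device count.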

Worked mini-example

import jax
import jax.numpy as jnp
import numpy as np

devices = np.array(jax.devices()[:2])          # 1-D array of up to 2 devices (1 on a single-CPU host)
mesh = jax.sharding.Mesh(devices, ('data',))   # 1-D mesh, axis named 'data'
print(mesh.axis_names)                          # ('data',)
print(len(mesh.axis_names))                     # 1
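The axis name only matters once something refers to it. As a sketch (the array shape and the choice to split dimension 0 are illustrative assumptions), the 'data' axis can be combined with NamedSharding and PartitionSpec to place an array so that its first dimension is split across the mesh:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:2])          # up to 2 devices, 1-D
mesh = Mesh(devices, ('data',))
n = devices.size

# Rows = 2 per device, so dimension 0 divides evenly across the 'data' axis.
x = jnp.arange(n * 2 * 3, dtype=jnp.float32).reshape(n * 2, 3)

# Split dimension 0 over mesh axis 'data'; keep dimension 1 unsplit (replicated).
sharding = NamedSharding(mesh, P('data', None))
x_sharded = jax.device_put(x, sharding)

for shard in x_sharded.addressable_shards:
    print(shard.device, shard.data.shape)      # each shard holds a (2, 3) block
```

On the single-CPU runner there is one shard holding the whole array; with two devices each device holds two of the four rows.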

Common pitfalls

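  • Pass devices as a NumPy array whose shape is the grid you want; its number of dimensions must equal len(axis_names). Depending on the JAX version, a plain Python list may be rejected or converted implicitly, so wrap it explicitly: np.array(jax.devices()[:n]).
  • axis_names is a tuple of strings, one per mesh dimension. A 1-D mesh needs a 1-element tuple, ('data',); note the trailing comma, since ('data') is just the string 'data'.
  • Always slice jax.devices() to the desired count ([:n]). Slicing past the end does not raise an error; it silently returns fewer devices. On the single-CPU test runner jax.devices() returns exactly one device, so use n=1.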
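The three pitfalls can be guarded against in one place. The helper below, make_1d_mesh, is hypothetical (it is not part of JAX or of the problem); it is a sketch that refuses to build a mesh when fewer devices exist than requested, instead of relying on the silent truncation of jax.devices()[:n].

```python
import jax
import numpy as np
from jax.sharding import Mesh


def make_1d_mesh(n: int, axis_name: str = 'data') -> Mesh:
    """Hypothetical helper: build a 1-D mesh over exactly n devices."""
    available = jax.devices()
    if n < 1 or n > len(available):
        # jax.devices()[:n] would silently return fewer devices than requested.
        raise ValueError(f"requested {n} devices, but {len(available)} are available")
    devices = np.array(available[:n])      # explicit NumPy array of shape (n,)
    return Mesh(devices, (axis_name,))     # trailing comma: a 1-element tuple


mesh = make_1d_mesh(1)
print(mesh.axis_names)                     # ('data',)
```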

Problem

Implement mesh_axis_count(n_devices) that:

  1. Casts n_devices to int.
  2. Takes the first n devices with jax.devices()[:n] and wraps them in a NumPy array.
  3. Creates a 1-D jax.sharding.Mesh with axis name 'data'.
  4. Returns jnp.float32(len(mesh.axis_names)), which is always 1.0 for a 1-D mesh.

Single-device caveat: the test runner has exactly 1 CPU device, so n_devices is always 1.0. The API is identical regardless of device count; this problem demonstrates it on 1 device.
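If you want to exercise a multi-device mesh, such as the 2×2 example above, without extra hardware, one common approach is to ask XLA to expose several virtual CPU devices. This is an assumption about your local setup, not part of the problem: the XLA_FLAGS option below must be set before JAX is first imported in the process, and it only affects the CPU backend.

```python
import os

# Ask XLA for 4 virtual CPU devices. This must run before `import jax`;
# the flag is read when JAX initializes its backends.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import numpy as np

cpu_devices = jax.devices("cpu")[:4]
devices = np.array(cpu_devices).reshape(2, 2)
mesh = jax.sharding.Mesh(devices, ("batch", "model"))

print(len(cpu_devices))       # 4
print(len(mesh.axis_names))   # 2
```

The mesh code itself is unchanged from the single-device case; only the number of devices available to slice from differs.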

Arguments:

  • n_devices: scalar (cast to int).

Returns: scalar float32 โ€” number of mesh axes (always 1.0 on this runner).

Example (not from the test set):

  • mesh_axis_count(1.0) → 1.0
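The four steps translate almost line for line into code. This is a minimal sketch that uses only the APIs named in the steps, not necessarily the reference solution:

```python
import jax
import jax.numpy as jnp
import numpy as np


def mesh_axis_count(n_devices):
    n = int(n_devices)                              # step 1: cast to int
    devices = np.array(jax.devices()[:n])           # step 2: first n devices as a NumPy array
    mesh = jax.sharding.Mesh(devices, ('data',))    # step 3: 1-D mesh with axis 'data'
    return jnp.float32(len(mesh.axis_names))        # step 4: number of axes as float32


print(mesh_axis_count(1.0))                         # 1.0
```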

