Mesh Creation
Why this matters
In JAX's distributed-computing model, a Mesh is the device topology: a
multi-dimensional grid of physical devices (CPUs, GPUs, or TPUs) where each
grid axis is given a name. Every sharding strategy you write later, with
PartitionSpec and NamedSharding, refers to these axis names to say
which dimension of an array maps to which group of devices.
import jax, numpy as np
# Assumes at least 4 devices; with fewer, reshape(2, 2) raises a ValueError.
devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = jax.sharding.Mesh(devices, ('batch', 'model'))
# A 2×2 mesh: axis 'batch' has 2 slices, axis 'model' has 2 slices.
The Mesh object owns the mapping from axis names to device slices.
len(mesh.axis_names) is the number of named axes (i.e. the number of
dimensions in the device grid).
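To make the link between the device grid and the axis names concrete, here is a minimal sketch that should run even on a single device. It reshapes one device into a 1×1 grid, so the mesh has two named axes although each has size 1. The variable names grid, mesh_2d, and n_axes are illustrative only.

```python
import jax
import numpy as np
from jax.sharding import Mesh

# A single device reshaped into a 1x1 grid still forms a valid 2-D mesh.
grid = np.array(jax.devices()[:1]).reshape(1, 1)
mesh_2d = Mesh(grid, ('batch', 'model'))

# One axis name per grid dimension.
n_axes = len(mesh_2d.axis_names)
print(n_axes, grid.ndim)
```

The number of axis names has to match the number of dimensions of the device array, which is why a 2-D grid takes a 2-element tuple.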
Worked mini-example
import jax
import jax.numpy as jnp
import numpy as np
devices = np.array(jax.devices()[:2]) # 1-D array of up to 2 devices
mesh = jax.sharding.Mesh(devices, ('data',)) # 1-D mesh, axis named 'data'
print(mesh.axis_names) # ('data',)
print(len(mesh.axis_names)) # 1
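As a follow-on to the mini-example, the sketch below shows how the 'data' axis name might be used by PartitionSpec and NamedSharding to place an array. It assumes only that at least one device is available; the array shape (4 * n, 3) is an arbitrary choice made so that the leading dimension divides evenly across the mesh.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Use every available device (just 1 on a single-CPU machine).
devices = np.array(jax.devices())
mesh = Mesh(devices, ('data',))
n = devices.size

# Split the leading dimension across the 'data' axis; replicate the second.
sharding = NamedSharding(mesh, PartitionSpec('data', None))
x = jnp.zeros((4 * n, 3))
x_sharded = jax.device_put(x, sharding)

print(dict(mesh.shape))   # axis name -> axis size
print(x_sharded.shape)    # the global shape is unchanged by sharding
```

Sharding changes where the data lives, not the logical shape of the array, so x_sharded reports the same shape as x.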
Common pitfalls
- Pass devices as a NumPy array, not a plain Python list: use np.array(jax.devices()[:n]). Depending on the JAX version, a plain list may raise a TypeError or be converted for you; the explicit array works either way.
- axis_names is a tuple of strings, one per mesh dimension. A 1-D mesh needs a 1-element tuple: ('data',).
- Always slice jax.devices() to the desired count ([:n]). On the single-CPU test runner, jax.devices() returns exactly one device, so use n=1.
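The pitfalls above can be guarded against with a small wrapper. This is only a sketch: make_data_mesh is a hypothetical helper, not part of JAX, and raising a ValueError when too many devices are requested is an assumption about the behaviour you would want.

```python
import jax
import numpy as np
from jax.sharding import Mesh

def make_data_mesh(n):
    """Hypothetical helper: 1-D mesh over the first n available devices."""
    available = jax.devices()
    n = int(n)
    if not 1 <= n <= len(available):
        raise ValueError(
            f"requested {n} devices, but {len(available)} are available")
    # Explicit NumPy array and a 1-element tuple of axis names.
    return Mesh(np.array(available[:n]), ('data',))

mesh = make_data_mesh(1)
print(mesh.axis_names)
```

Checking the count up front turns a silently smaller mesh (slicing past the end of a list does not fail) into an explicit error.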
Problem
Implement mesh_axis_count(n_devices) that:
- Casts n_devices to int.
- Takes the first n devices with jax.devices()[:n] and wraps them in a NumPy array.
- Creates a 1-D jax.sharding.Mesh with axis name 'data'.
- Returns jnp.float32(len(mesh.axis_names)), which is always 1.0 for a 1-D mesh.
Single-device caveat: the test runner has exactly 1 CPU device, so
n_devices is always 1.0. The API is identical regardless of device count;
this problem demonstrates it on 1 device.
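The four steps listed above can be sketched as follows. This is one possible reading of the problem statement, not the reference solution, and it assumes n_devices is a concrete scalar (so that int() can convert it).

```python
import jax
import jax.numpy as jnp
import numpy as np

def mesh_axis_count(n_devices):
    # Step 1: cast the scalar input to a Python int.
    n = int(n_devices)
    # Step 2: take the first n devices as a NumPy array.
    devices = np.array(jax.devices()[:n])
    # Step 3: build a 1-D mesh with a single axis named 'data'.
    mesh = jax.sharding.Mesh(devices, ('data',))
    # Step 4: return the number of named axes as a float32 scalar.
    return jnp.float32(len(mesh.axis_names))

result = mesh_axis_count(1.0)
print(result)
```

Because the mesh is always 1-D here, the result depends on the number of axis names and not on how many devices were requested.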
- n_devices: scalar (cast to int).
Returns: scalar float32, the number of mesh axes (always 1.0 on this runner).
Example (not from the test set):
- mesh_axis_count(1.0) → 1.0