PartitionSpec Basics
Why this matters
Once you have a Mesh, you need to say how each axis of your array maps
onto the axes of the mesh. That mapping is a PartitionSpec.
PartitionSpec(axis_name_or_None, ...) has one entry per array dimension:
- A string (e.g. 'data') means shard this array axis across the named mesh axis. With 8 devices on axis 'data', each device holds 1/8 of the rows.
- None means replicate this array axis: every device gets a full copy of that dimension.
PartitionSpec('data', None)
# array axis 0 → sharded across mesh axis 'data'
# array axis 1 → fully replicated
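To make the 1/8 arithmetic concrete, here is a small pure-Python sketch. It is not a JAX API: the helper local_shard_shape and its mesh_axis_sizes argument are hypothetical names introduced only for illustration, and the sketch assumes each sharded dimension divides evenly by the size of its mesh axis.

```python
def local_shard_shape(global_shape, spec, mesh_axis_sizes):
    """Per-device block shape implied by a spec (illustrative helper, not a JAX API)."""
    shape = []
    for dim, axis_name in zip(global_shape, spec):
        if axis_name is None:
            # Replicated dimension: every device keeps the full extent.
            shape.append(dim)
        else:
            n = mesh_axis_sizes[axis_name]
            if dim % n != 0:
                raise ValueError(f"dimension {dim} is not divisible by {n}")
            # Sharded dimension: each device keeps 1/n of it.
            shape.append(dim // n)
    return tuple(shape)

# 16 rows over 8 devices on 'data': each device holds 2 rows and all 3 columns.
print(local_shard_shape((16, 3), ('data', None), {'data': 8}))  # (2, 3)
# With a single device the block is the whole array.
print(local_shard_shape((4, 3), ('data', None), {'data': 1}))   # (4, 3)
```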
The PartitionSpec is combined with a Mesh inside NamedSharding, and then
jax.device_put(x, sharding) physically places (or annotates) the array.
Worked mini-example
import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data', None))
x = jnp.ones((4, 3))
x_sharded = jax.device_put(x, sharding)
# x_sharded has the same values as x; on >1 devices rows would be split.
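One way to check what device_put did is to inspect the returned array. The following is a minimal sketch, assuming a single-device host as in the example above; it repeats the setup so that it runs on its own. The printed device names and index slices depend on the backend, so no expected output is shown for them.

```python
import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data', None))
x = jnp.ones((4, 3))
x_sharded = jax.device_put(x, sharding)

# The sharding attached to the array.
print(x_sharded.sharding)

# One entry per local device: which device, which slice of x, and the block shape.
for shard in x_sharded.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)

# The values are unchanged by placement.
same_values = bool(np.array_equal(np.asarray(x_sharded), np.asarray(x)))
print(same_values)
```

With a single device in the mesh there is one shard, and it covers the whole (4, 3) array.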
Common pitfalls
- PartitionSpec length must not exceed array ndim. A spec with more entries than the array has dimensions raises a ValueError. A shorter spec such as PartitionSpec('data',) is accepted on a 2-D array, with the missing trailing dimensions treated as replicated, but writing one entry per dimension keeps the intent explicit.
- None is not a string. Use Python None, not the string 'None'.
- Axis name must exist in the mesh. Referring to 'model' when the mesh only has 'data' raises a ValueError.
- On 1 device, sharding is metadata only: values are identical to the input. This problem verifies that the call succeeds and values are preserved.
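The error cases above can be reproduced on a single device. This is a hedged sketch: raises_value_error is a hypothetical helper added here for illustration, and the exact error messages vary between JAX versions, so only the exception type is checked.

```python
import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:1]), ('data',))

def raises_value_error(fn):
    """Illustrative helper: True if calling fn() raises ValueError."""
    try:
        fn()
    except ValueError:
        return True
    return False

# Axis name that does not exist in the mesh.
bad_axis = raises_value_error(
    lambda: NamedSharding(mesh, PartitionSpec('model', None)))

# The string 'None' is read as a mesh axis name, which also does not exist.
string_none = raises_value_error(
    lambda: NamedSharding(mesh, PartitionSpec('data', 'None')))

# A 2-entry spec applied to a 1-D array: more entries than dimensions.
too_long = raises_value_error(
    lambda: jax.device_put(
        jnp.ones((4,)), NamedSharding(mesh, PartitionSpec('data', None))))

print(bad_axis, string_none, too_long)
```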
Problem
Implement shard_first_axis(x) that:
- Creates a 1-device Mesh with axis name 'data'.
- Builds NamedSharding(mesh, PartitionSpec('data', None)).
- Returns jax.device_put(x, sharding).
Single-device caveat: the test runner has exactly 1 CPU device. Sharding is pure metadata on 1 device, so the returned array has the same values as the input. The API call is identical on a multi-device host; only the physical layout differs.
- x: 2-D JAX array.
Returns: 2-D array of the same shape and values as x.
Examples (not from the test set):
- shard_first_axis(jnp.ones((2, 3))) → jnp.ones((2, 3)) (values unchanged)
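The three steps in the problem statement can be sketched as follows. This is one possible arrangement, not necessarily the reference solution; it assumes the first entry of jax.devices() is the device to use.

```python
import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def shard_first_axis(x):
    # Step 1: a 1-device mesh whose only axis is named 'data'.
    mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
    # Step 2: shard array axis 0 over 'data', replicate array axis 1.
    sharding = NamedSharding(mesh, PartitionSpec('data', None))
    # Step 3: place (or, on 1 device, annotate) the array.
    return jax.device_put(x, sharding)

out = shard_first_axis(jnp.ones((2, 3)))
print(out.shape)  # (2, 3)
```

Building the mesh inside the function keeps the sketch self-contained. Code that shards many arrays would more likely build the mesh once and reuse it.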