
PartitionSpec Basics

Why this matters

Once you have a Mesh, you need to say how each axis of your array maps to each axis of the mesh. That mapping is a PartitionSpec.

PartitionSpec(axis_name_or_None, ...) has one entry per array dimension:

  • A string (e.g. 'data') means shard this array axis across the named mesh axis. With 8 devices on axis 'data', each device holds 1/8 of the rows.
  • None means replicate this array axis: every device gets a full copy of that dimension.
PartitionSpec('data', None)
# array axis 0 → sharded across mesh axis 'data'
# array axis 1 → fully replicated
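You can see what each entry does without needing several devices by asking a sharding for its per-device block shape with `NamedSharding.shard_shape`. The sketch below builds a one-axis mesh over whatever devices are available and compares a row-sharded spec with a fully replicated one. The names `n` and `global_shape`, and the choice of `(2 * n, 3)`, are arbitrary and only for illustration.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One mesh axis named 'data' spanning every available device.
devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, ('data',))

# A global shape whose first dimension divides evenly across n devices.
global_shape = (2 * n, 3)

# Array axis 0 split across 'data'; array axis 1 replicated.
row_sharded = NamedSharding(mesh, PartitionSpec('data', None))
# Both array axes replicated: every device holds the full array.
replicated = NamedSharding(mesh, PartitionSpec(None, None))

print(row_sharded.shard_shape(global_shape))  # (2, 3) for any n
print(replicated.shard_shape(global_shape))   # (2 * n, 3)
```

With 8 devices the row-sharded block is 2 of the 16 rows, i.e. 1/8 of the rows per device, matching the description of a string entry.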

The PartitionSpec is combined with a Mesh inside NamedSharding, and then jax.device_put(x, sharding) physically places (or annotates) the array.

Worked mini-example

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data', None))

x = jnp.ones((4, 3))
x_sharded = jax.device_put(x, sharding)
# x_sharded has the same values as x; on >1 devices rows would be split.
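On one device you cannot see the split, but you can simulate a multi-device host on CPU. This sketch assumes the XLA flag `--xla_force_host_platform_device_count` is honored, which requires setting it before JAX is first imported in the process. It then walks `addressable_shards` to show which rows each device holds. If the flag is ignored (for example because JAX was already imported), the code still runs and reports a single shard.

```python
import os

# Must be set before the first `import jax`; asks XLA to expose 4 CPU devices.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

cpus = jax.devices("cpu")
mesh = Mesh(np.array(cpus), ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data', None))

# Two rows per device so axis 0 always divides evenly.
rows = 2 * len(cpus)
x = jnp.arange(rows * 3, dtype=jnp.float32).reshape(rows, 3)
x_sharded = jax.device_put(x, sharding)

for shard in x_sharded.addressable_shards:
    # shard.index is a tuple of slices into the global array.
    print(shard.device, shard.index, shard.data.shape)

# The global view is unchanged; only the physical layout differs.
assert np.array_equal(np.asarray(x_sharded), np.asarray(x))
```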

Common pitfalls

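  • A PartitionSpec cannot have more entries than the array has dimensions. Passing PartitionSpec('data', None, None) to a 2-D array raises a ValueError. A shorter spec is accepted: missing trailing entries are treated as None, so PartitionSpec('data') on a 2-D array behaves like PartitionSpec('data', None). Writing one entry per dimension keeps the intent explicit.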
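  • None is not a string. Use Python None, not the string 'None'; a string entry is always looked up as a mesh axis name, so 'None' fails unless the mesh has an axis with that name.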
  • Axis name must exist in the mesh. Referring to 'model' when the mesh only has 'data' raises a ValueError.
  • On 1 device, sharding is metadata only: the values are identical to the input. This problem checks that the call succeeds and the values are preserved.
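The pitfalls can be checked directly on a single device. In this sketch, `raises_value_error` is a hypothetical helper written only for this demonstration; it reports whether building the sharding or placing a 2-D array with it raises ValueError. It also shows that a spec shorter than the array rank is accepted, with trailing axes treated as None.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
x = jnp.ones((4, 3))  # a 2-D array


def raises_value_error(spec):
    """Hypothetical helper: True if sharding x with `spec` raises ValueError."""
    try:
        jax.device_put(x, NamedSharding(mesh, spec))
    except ValueError:
        return True
    return False


# More entries than dimensions: rejected.
too_long = raises_value_error(PartitionSpec('data', None, None))
# Axis names must exist in the mesh.
unknown_axis = raises_value_error(PartitionSpec('model', None))
# The string 'None' is looked up as an axis name, so it fails too.
string_none = raises_value_error(PartitionSpec('data', 'None'))
# Fewer entries than dimensions: trailing axes are treated as None.
short_spec = raises_value_error(PartitionSpec('data'))

print(too_long, unknown_axis, string_none, short_spec)  # True True True False
```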

Problem

Implement shard_first_axis(x) that:

  1. Creates a 1-device Mesh with axis name 'data'.
  2. Builds NamedSharding(mesh, PartitionSpec('data', None)).
  3. Returns jax.device_put(x, sharding).

Single-device caveat: the test runner has exactly 1 CPU device. Sharding is pure metadata on 1 device, so the returned array has the same values as the input. The API call is identical on a multi-device host; only the physical layout differs.

  • x: 2-D JAX array.

Returns: 2-D array of the same shape and values as x.

Examples (not from the test set):

  • shard_first_axis(jnp.ones((2, 3))) → jnp.ones((2, 3)) (values unchanged)
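Following the three steps in the problem statement, one possible sketch looks like this. It is not necessarily the site's reference solution; it simply mirrors the worked mini-example with the mesh fixed to the first device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_first_axis(x):
    """Place a 2-D array with axis 0 sharded over a 1-device 'data' mesh axis."""
    # 1. A 1-device mesh whose single axis is named 'data'.
    mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
    # 2. Shard array axis 0 over 'data'; replicate array axis 1.
    sharding = NamedSharding(mesh, PartitionSpec('data', None))
    # 3. Place the array; on one device this only attaches sharding metadata.
    return jax.device_put(x, sharding)


y = shard_first_axis(jnp.ones((2, 3)))
print(y.shape, y.sharding)
```

Building the mesh inside the function keeps it self-contained; on a multi-device host you would typically create the mesh once and reuse it.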

Hints

jax partition-spec sharding
