medium primitives

PartitionSpec Layout

Why this matters

When a model is too big for one accelerator, you carve each tensor across a device mesh: a logical grid of devices with named axes. Every parameter then declares which mesh axis (if any) each of its dimensions is sharded along; that declaration is a PartitionSpec. Get the spec wrong and the compiler either errors out or, worse, silently replicates a tensor that should have been sharded, so a 70B-parameter model costs the full 70B parameters of memory on every device instead of a fraction of that.

What is a PartitionSpec?

jax.sharding.PartitionSpec (alias P) is a tuple where each position maps to one axis of the tensor. The value tells the compiler which mesh axis (if any) to shard along:

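import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec as P, Mesh, NamedSharding

# Requires 8 devices; create_device_mesh fails if fewer are available.
devices = mesh_utils.create_device_mesh((2, 4))   # 2x4 grid
mesh = Mesh(devices, axis_names=("data", "model"))

# A Dense kernel of shape (in, out); the sizes here are illustrative.
kernel = jnp.ones((8, 16))
spec = P(None, "model")           # axis 0 replicated; axis 1 shards over "model" (4 devices)
sharding = NamedSharding(mesh, spec)
sharded_kernel = jax.device_put(kernel, sharding)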

Reading the spec aloud: “axis 0 is replicated everywhere, axis 1 is split across the ‘model’ mesh axis.” If the mesh’s model axis has 4 devices and out=16, each device holds an (in, 4) slice.

The three cases per axis

For each tensor axis, the spec entry is one of:

entry        meaning
None         axis is replicated: every device holds the full extent along this axis
"name"       axis is sharded along the mesh axis named "name": divided by the mesh-axis size
("a", "b")   axis is sharded along both mesh axes: divided by their product
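The three cases can be turned into a small shape calculator. The sketch below is plain Python, not a JAX API: per_shard_shape is a hypothetical helper, and the mesh is represented only by an assumed dict of axis sizes that mirrors the 2x4 ("data", "model") mesh used earlier. It also assumes every sharded axis divides evenly.

```python
from math import prod


def per_shard_shape(global_shape, spec, mesh_axis_sizes):
    """Return the per-device shape of a tensor under a PartitionSpec-like tuple.

    spec entries are None (replicated), a mesh-axis name, or a tuple of names.
    mesh_axis_sizes maps each mesh-axis name to its number of devices.
    """
    # A spec shorter than the tensor rank leaves the trailing axes replicated.
    padded = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    shape = []
    for dim, entry in zip(global_shape, padded):
        if entry is None:
            factor = 1                                   # replicated
        elif isinstance(entry, str):
            factor = mesh_axis_sizes[entry]              # one mesh axis
        else:
            factor = prod(mesh_axis_sizes[n] for n in entry)  # several mesh axes
        if dim % factor != 0:
            raise ValueError(f"axis of size {dim} is not divisible by {factor}")
        shape.append(dim // factor)
    return tuple(shape)


mesh_axis_sizes = {"data": 2, "model": 4}  # assumed 2x4 mesh
print(per_shard_shape((8, 16), (None, "model"), mesh_axis_sizes))            # (8, 4)
print(per_shard_shape((8, 16), ("data", "model"), mesh_axis_sizes))          # (4, 4)
print(per_shard_shape((8, 16), (("data", "model"), None), mesh_axis_sizes))  # (1, 16)
```

Because it only does integer arithmetic on shapes, this runs on a single device and can be used to sanity-check a spec before placing real arrays.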

Standard layouts you’ll see

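  • Megatron tensor parallel (column-parallel Dense): P(None, "model"). The input axis is replicated across model-parallel devices; the output axis shards.
  • Row-parallel Dense: P("model", None). The input axis shards and the output axis is replicated; each device produces a partial sum, so a final all-reduce is needed.
  • FSDP / data-parallel: P("data", None). On a (batch, features) activation, the batch shards over "data" and the features are replicated; FSDP uses the same form of spec on parameters, sharding one weight axis over "data".
  • 2-D sharding: P("data", "model"). Batch on the data axis, hidden on the model axis; common for MoE layers.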

Single-device sandbox

A real shard is created by jax.device_put(x, NamedSharding(mesh, spec)) on a multi-device process. The sandbox has one device and one host, so mesh_utils.create_device_mesh((2, 4)) would fail because only one device is available.

Instead, this problem asks the accounting question that any engineer placing a PartitionSpec(None, "model") annotation has to compute by hand: what is the shape of one shard?

For a (kernel_in, kernel_out) Dense kernel with PartitionSpec(None, "model") over num_devices model-parallel devices:

  • Axis 0 (kernel_in) is None → replicated → per-shard size stays at kernel_in.
  • Axis 1 (kernel_out) is "model" → sharded across the mesh axis → per-shard size is kernel_out // num_devices.

Assume even division (num_devices divides kernel_out exactly).

Problem

Given (num_devices, kernel_in, kernel_out), compute and return [kernel_in_size, kernel_out_per_shard] as a 1-D (2,) float array, where:

  • kernel_in_size = kernel_in (replicated; same on every shard).
  • kernel_out_per_shard = kernel_out // num_devices (sharded).

Inputs:

  • num_devices: int, the size of the "model" mesh axis.
  • kernel_in: int, the input dimension of the kernel.
  • kernel_out: int, the output dimension; divisible by num_devices.

Output: 1-D (2,) array [kernel_in_size, kernel_out_per_shard].
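
One possible way to write this computation is sketched below; it is not the site's reference solution. The function name column_parallel_shard_shape, the use of NumPy rather than jax.numpy, and the float32 dtype are all assumptions, since the problem only asks for a (2,) float array.

```python
import numpy as np


def column_parallel_shard_shape(num_devices, kernel_in, kernel_out):
    """Per-shard shape of a (kernel_in, kernel_out) kernel under P(None, "model")."""
    # The problem statement guarantees even division; check it anyway.
    if kernel_out % num_devices != 0:
        raise ValueError("kernel_out must be divisible by num_devices")
    kernel_in_size = kernel_in                      # axis 0: replicated
    kernel_out_per_shard = kernel_out // num_devices  # axis 1: sharded over "model"
    # float32 is an assumed dtype; the problem only says "float array".
    return np.array([kernel_in_size, kernel_out_per_shard], dtype=np.float32)


result = column_parallel_shard_shape(num_devices=4, kernel_in=512, kernel_out=1024)
print(result.shape, result.tolist())  # (2,) [512.0, 256.0]
```

With 4 model-parallel devices, a 512 x 1024 kernel leaves each device holding a 512 x 256 slice: the input dimension is unchanged and the output dimension is divided by the mesh-axis size.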

Hints

flax sharding partition-spec
