PartitionSpec Layout
Why this matters
When a model is too big for one accelerator, you carve each tensor
across a device mesh: a logical grid of devices with named axes.
Every parameter then declares which mesh axis (if any) it's sharded
along; that declaration is a PartitionSpec. Get the spec wrong
and the compiler either errors out or, worse, silently replicates a
tensor that should have been sharded, so every device ends up holding
all 70B parameters of a 70B-parameter model instead of only its share.
What is a PartitionSpec?
jax.sharding.PartitionSpec (alias P) is a tuple where each
position maps to one axis of the tensor. The value tells the
compiler which mesh axis (if any) to shard along:
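import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec as P, Mesh, NamedSharding
devices = mesh_utils.create_device_mesh((2, 4)) # 2x4 grid; requires 8 devices
mesh = Mesh(devices, axis_names=("data", "model"))
# A Dense kernel of shape (in, out); the (8, 16) shape is only an example:
kernel = jnp.ones((8, 16))
spec = P(None, "model") # axis 0 replicated; axis 1 shards over "model" (4 devices)
sharding = NamedSharding(mesh, spec)
sharded_kernel = jax.device_put(kernel, sharding) # each device holds an (8, 4) slice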
Reading the spec aloud: "axis 0 is replicated everywhere, axis 1
is split across the 'model' mesh axis." If the mesh's model axis
has 4 devices and out=16, each device holds an (in, 4) slice.
The three cases per axis
For each tensor axis, the spec entry is one of:
| entry | meaning |
|---|---|
| None | axis is replicated: every device holds the full slice along this axis |
| "name" | axis is sharded along the mesh axis named "name": divided by the mesh-axis size |
| ("a", "b") | axis is sharded along both mesh axes: divided by their product |
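The three cases in the table can be checked with a small pure-Python helper. This is a sketch for shape accounting only: `per_shard_shape` and the `mesh_shape` dictionary are hypothetical names introduced here, not part of the JAX API, and the helper assumes every sharded axis divides evenly.

```python
from math import prod


def per_shard_shape(global_shape, spec, mesh_shape):
    """Return the per-device shape implied by a PartitionSpec-like tuple.

    global_shape: full tensor shape, e.g. (8, 16).
    spec: one entry per tensor axis: None, a mesh-axis name,
          or a tuple of mesh-axis names (hypothetical plain-tuple stand-in).
    mesh_shape: dict mapping mesh-axis name -> number of devices on that axis.
    """
    # Missing trailing entries are treated as replicated (None).
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    shard = []
    for dim, entry in zip(global_shape, spec):
        if entry is None:
            factor = 1                      # replicated: full extent everywhere
        elif isinstance(entry, str):
            factor = mesh_shape[entry]      # sharded over one mesh axis
        else:
            # sharded over several mesh axes: divide by their product
            factor = prod(mesh_shape[name] for name in entry)
        if dim % factor != 0:
            raise ValueError(f"axis of size {dim} is not divisible by {factor}")
        shard.append(dim // factor)
    return tuple(shard)


mesh_shape = {"data": 2, "model": 4}
print(per_shard_shape((8, 16), (None, "model"), mesh_shape))            # (8, 4)
print(per_shard_shape((8, 16), ("data", "model"), mesh_shape))          # (4, 4)
print(per_shard_shape((8, 16), (("data", "model"), None), mesh_shape))  # (1, 16)
```

Because the helper never touches a device, it runs anywhere and can serve as a sanity check before placing a real annotation.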
Standard layouts you'll see
- Megatron tensor parallel (column-parallel Dense): P(None, "model"). Input axis replicated across model-parallel devices, output axis sharded.
- Row-parallel Dense: P("model", None). Input axis sharded, output axis replicated; needs a final all-reduce.
- FSDP / data-parallel: P("data", None). Batch shards, features replicated.
- 2-D sharding: P("data", "model"). Batch on data axis, hidden on model axis; common for MoE layers.
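To see why the row-parallel layout needs a final all-reduce while the column-parallel one does not, the two splits can be emulated on a single device with plain array slicing. This is an illustrative sketch, not how JAX executes sharded computations: the shapes, the `num_shards` value, and the Python `sum` standing in for the all-reduce are all assumptions chosen for the example.

```python
import jax.numpy as jnp

num_shards = 4  # stands in for the size of the "model" mesh axis

x = jnp.arange(2 * 8, dtype=jnp.float32).reshape(2, 8)    # (batch, in)
w = jnp.arange(8 * 16, dtype=jnp.float32).reshape(8, 16)  # (in, out)
reference = x @ w                                         # unsharded result

# Column-parallel, P(None, "model"): each shard owns a slice of the output axis.
w_cols = jnp.split(w, num_shards, axis=1)                 # 4 pieces of shape (8, 4)
col_out = jnp.concatenate([x @ wc for wc in w_cols], axis=1)

# Row-parallel, P("model", None): each shard owns a slice of the input axis,
# so it only produces a partial sum of the full output.
w_rows = jnp.split(w, num_shards, axis=0)                 # 4 pieces of shape (2, 16)
x_cols = jnp.split(x, num_shards, axis=1)                 # matching input slices (2, 2)
partials = [xc @ wr for xc, wr in zip(x_cols, w_rows)]
row_out = sum(partials)                                   # emulates the all-reduce

print(bool(jnp.allclose(col_out, reference)))
print(bool(jnp.allclose(row_out, reference)))
```

In the column-parallel case each shard's output is a finished slice that only needs concatenating; in the row-parallel case every shard holds a full-shaped but incomplete result, which is why the partial results must be summed across devices.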
Single-device sandbox
A real shard is created by jax.device_put(x, NamedSharding(mesh, spec))
on a multi-device process. The sandbox has one device and one host, so
mesh_utils.create_device_mesh((2, 4)) would fail because only
1 device is available.
Instead, this problem asks the accounting question that any engineer
placing a PartitionSpec(None, "model") annotation has to compute by
hand: what is the shape of one shard?
For a (kernel_in, kernel_out) Dense kernel with
PartitionSpec(None, "model") over num_devices model-parallel
devices:
- Axis 0 (kernel_in) is None: replicated, so the per-shard size stays at kernel_in.
- Axis 1 (kernel_out) is "model": sharded across the mesh axis, so the per-shard size is kernel_out // num_devices.
Assume even division (num_devices divides kernel_out exactly).
Problem
Given (num_devices, kernel_in, kernel_out), compute and return
[kernel_in_size, kernel_out_per_shard] as a 1-D (2,) float array,
where:
- kernel_in_size = kernel_in (replicated; same on every shard).
- kernel_out_per_shard = kernel_out // num_devices (sharded).
Inputs:
- num_devices: int, size of the "model" mesh axis.
- kernel_in: int, input dimension of the kernel.
- kernel_out: int, output dimension; divisible by num_devices.
Output: 1-D (2,) array [kernel_in_size, kernel_out_per_shard].
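One possible way to express this computation is sketched below. The function name `shard_shape` and the choice of `float32` are assumptions, since the problem statement only asks for a 1-D (2,) float array and does not fix a signature or a precision.

```python
import jax.numpy as jnp


def shard_shape(num_devices, kernel_in, kernel_out):
    """Per-shard shape of a (kernel_in, kernel_out) kernel under P(None, "model")."""
    kernel_in_size = kernel_in                        # axis 0: replicated, unchanged
    kernel_out_per_shard = kernel_out // num_devices  # axis 1: split over "model"
    # float32 is an assumed dtype; the problem only says "float array".
    return jnp.array([kernel_in_size, kernel_out_per_shard], dtype=jnp.float32)


result = shard_shape(num_devices=4, kernel_in=8, kernel_out=16)
print(result)  # holds the values 8.0 and 4.0
```

This matches the earlier example: with 4 devices on the "model" axis and out=16, each device holds an (in, 4) slice.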