NNX PartitionSpec Layout
Why this matters
A 70B-parameter transformer doesn’t fit on one accelerator. The way
out is to shard every tensor across a logical grid of devices —
a Mesh — and tell the compiler which axes split where via a
PartitionSpec. Get the spec wrong and the compiler either errors
out or, worse, silently replicates a parameter that should have
been sharded — turning a 70B model into “70B per device”, which
OOMs on every chip.
Whether you write Linen or NNX, PartitionSpec is the same JAX
primitive — the sharding language is independent of the module
library. NNX layers take the same Mesh and NamedSharding
objects you’d use with raw JAX.
What is a PartitionSpec?
jax.sharding.PartitionSpec (alias P) is a tuple where each
position maps to one axis of the tensor. The value tells the
compiler which mesh axis (if any) to shard along:
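import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, Mesh, NamedSharding
from jax.experimental import mesh_utils

# Needs 8 devices: a (2, 4) mesh cannot be built on fewer.
devices = mesh_utils.create_device_mesh((2, 4))  # 2x4 grid
mesh = Mesh(devices, axis_names=("data", "model"))

# An nnx.Linear's kernel has shape (in, out); a zero array stands in for it here.
kernel = jnp.zeros((8, 16))
spec = P(None, "model")  # axis 0 replicated, axis 1 split over "model"
sharding = NamedSharding(mesh, spec)
sharded_kernel = jax.device_put(kernel, sharding)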
Reading the spec aloud: “axis 0 is replicated everywhere, axis 1
is split across the ‘model’ mesh axis.” If the mesh’s model axis
has 4 devices and out=16, each device holds an (in, 4) slice.
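The reading above can be checked without a real accelerator grid. The following is a minimal sketch, assuming a CPU-only JAX install: the XLA_FLAGS setting asks XLA for 8 virtual host devices and only takes effect if it is set before jax is imported. To stay runnable on any device count, the mesh is built as (1, n) over whatever devices JAX reports, and the kernel sizes are illustrative values chosen so that the output axis divides evenly.

```python
import os

# Assumption: request 8 virtual CPU devices. This only takes effect if it is
# set before jax is imported, and it is ignored on non-CPU backends.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Use every visible device on the "model" axis; "data" has size 1.
devices = np.array(jax.devices())
n_model = devices.size
mesh = Mesh(devices.reshape(1, n_model), axis_names=("data", "model"))

# Illustrative kernel sizes: out is a multiple of the mesh-axis size.
kernel_in, kernel_out = 8, 4 * n_model
sharding = NamedSharding(mesh, P(None, "model"))

# Ask the sharding what one shard of a (kernel_in, kernel_out) array looks like.
shard_shape = sharding.shard_shape((kernel_in, kernel_out))

# Place an actual array and inspect the first shard this process holds.
kernel = jax.device_put(jnp.zeros((kernel_in, kernel_out)), sharding)
first_shard_shape = kernel.addressable_shards[0].data.shape

print(n_model, shard_shape, first_shard_shape)
```

Axis 0 keeps its full size on every shard, while axis 1 is divided by the number of devices on the "model" axis.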
The three cases per axis
For each tensor axis, the spec entry is one of:
| entry | meaning |
|---|---|
| None | axis is replicated: every device holds the full extent along this axis |
| "name" | axis is sharded along the mesh axis named "name": divided by the mesh-axis size |
| ("a", "b") | axis is sharded along both mesh axes: divided by the product of their sizes |
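The three cases can be turned into a small shape calculator. This is a plain-Python sketch of the bookkeeping, not a JAX API: the function name shard_shape and the mesh_sizes dictionary are hypothetical, and the spec is passed as an ordinary tuple. It follows JAX's convention that a spec shorter than the tensor's rank leaves the remaining trailing axes replicated; rejecting uneven division and over-long specs is a choice made for this sketch.

```python
from math import prod


def shard_shape(global_shape, spec, mesh_sizes):
    """Per-device shape of a tensor under a PartitionSpec-like tuple.

    global_shape: tuple of ints, the full tensor shape.
    spec: tuple whose entries are None, a mesh-axis name, or a tuple of names.
    mesh_sizes: dict mapping mesh-axis name to its size (hypothetical helper input).
    """
    if len(spec) > len(global_shape):
        raise ValueError("spec has more entries than the tensor has axes")
    # Missing trailing entries behave like None (replicated).
    padded = tuple(spec) + (None,) * (len(global_shape) - len(spec))

    result = []
    for dim, entry in zip(global_shape, padded):
        if entry is None:
            divisor = 1  # replicated
        elif isinstance(entry, str):
            divisor = mesh_sizes[entry]  # sharded along one mesh axis
        else:
            divisor = prod(mesh_sizes[name] for name in entry)  # several axes
        if dim % divisor != 0:
            raise ValueError(f"axis of size {dim} is not divisible by {divisor}")
        result.append(dim // divisor)
    return tuple(result)


mesh_sizes = {"data": 2, "model": 4}
print(shard_shape((8, 16), (None, "model"), mesh_sizes))            # (8, 4)
print(shard_shape((8, 16), (("data", "model"), None), mesh_sizes))  # (1, 16)
print(shard_shape((8, 16), ("data", "model"), mesh_sizes))          # (4, 4)
```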
Standard layouts you’ll see
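- Megatron column-parallel Linear: P(None, "model"). Input replicated across model-parallel devices, output sharded.
- Row-parallel Linear: P("model", None). Input sharded, output replicated; each device produces a partial sum, so a final all-reduce is needed.
- FSDP / data-parallel: P("data", None). Axis 0 is sharded over the data axis and axis 1 is replicated. On a (batch, features) activation this shards the batch; on a parameter it is the FSDP-style split of the weight's first axis.
- 2-D sharding: P("data", "model"). Batch on the data axis, hidden on the model axis; common for MoE layers.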
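To see why the row-parallel layout needs an all-reduce while the column-parallel one does not, the two splits can be emulated on a single host with NumPy. This is only a numerical sketch: the shapes, the shard count n, and the variable names are illustrative assumptions, a Python list stands in for the devices, and a plain sum stands in for the all-reduce.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))   # activations, shape (batch, in)
w = rng.standard_normal((8, 16))  # kernel, shape (in, out)
n = 4                             # size of the "model" mesh axis (assumed)

full = x @ w  # reference result on one device

# Column-parallel, P(None, "model"): split the kernel along the out axis.
# Each "device" sees the whole input and produces a slice of the output.
col_shards = np.split(w, n, axis=1)
col_out = np.concatenate([x @ w_shard for w_shard in col_shards], axis=1)

# Row-parallel, P("model", None): split the kernel along the in axis.
# Each "device" sees a slice of the input and produces a full-size partial sum;
# adding the partial sums plays the role of the all-reduce.
row_shards = np.split(w, n, axis=0)
x_shards = np.split(x, n, axis=1)
row_out = sum(x_shard @ w_shard for x_shard, w_shard in zip(x_shards, row_shards))

print(np.allclose(col_out, full), np.allclose(row_out, full))
```

Both reconstructions match the unsharded matmul; they differ only in whether the per-device results are concatenated or summed.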
Single-device sandbox
A real shard is created by jax.device_put(x, NamedSharding(mesh, spec))
on a multi-device process. The sandbox has one device — there’s no
(2, 4) mesh to be had.
Instead, this problem asks the accounting question that any
engineer placing a PartitionSpec(None, "model") annotation has
to compute by hand: what is the shape of one shard?
For an nnx.Linear(in, out) kernel of shape (kernel_in, kernel_out)
with PartitionSpec(None, "model") over num_devices model-parallel
devices:
- Axis 0 (kernel_in) is None → replicated → per-shard size stays at kernel_in.
- Axis 1 (kernel_out) is "model" → sharded across the mesh axis → per-shard size is kernel_out // num_devices.
Assume even division (num_devices divides kernel_out exactly).
Common pitfalls
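- Off-by-one positionally: PartitionSpec positions match tensor axes left to right. A spec shorter than the tensor's rank is accepted, with the missing trailing axes treated as replicated, so nothing warns you when an entry lands on the wrong axis.
- Confusing mesh-axis size with tensor-axis size: the divisor is the size of the mesh axis named "model", not the kernel's output dim.
- Forgetting None for replicated axes: dropping a leading None shifts every later entry one axis to the left. P("model") on an (in, out) kernel shards axis 0, not axis 1.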
Problem
Given (num_devices, kernel_in, kernel_out), compute and return
[kernel_in_size, kernel_out_per_shard] as a 1-D (2,) float
array, where:
- kernel_in_size = kernel_in (replicated; same on every shard).
- kernel_out_per_shard = kernel_out // num_devices (sharded).
Inputs:
- num_devices: float (cast to int), the size of the "model" mesh axis.
- kernel_in: float (cast to int), the input dimension of the kernel.
- kernel_out: float (cast to int), the output dimension; divisible by num_devices.
Output: 1-D (2,) array — [kernel_in_size, kernel_out_per_shard].
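One possible way to express this computation is sketched below. It is not the site's reference solution: the function name linear_shard_sizes is hypothetical, jax.numpy with float32 is an assumed choice for the "float array", and the int() casts assume the inputs are concrete values rather than values traced under jax.jit.

```python
import jax.numpy as jnp


def linear_shard_sizes(num_devices, kernel_in, kernel_out):
    """Per-shard kernel sizes under PartitionSpec(None, "model").

    Hypothetical helper: inputs arrive as floats and are cast to int first.
    """
    num_devices = int(num_devices)
    kernel_in = int(kernel_in)
    kernel_out = int(kernel_out)

    kernel_in_size = kernel_in                      # axis 0: replicated
    kernel_out_per_shard = kernel_out // num_devices  # axis 1: sharded

    return jnp.array([kernel_in_size, kernel_out_per_shard], dtype=jnp.float32)


# Example: 4 model-parallel devices, an (8, 16) kernel -> each shard is (8, 4).
result = linear_shard_sizes(4.0, 8.0, 16.0)
print(result.shape, result.tolist())
```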