
NNX PartitionSpec Layout

Why this matters

A 70B-parameter transformer doesn’t fit on one accelerator. The way out is to shard every tensor across a logical grid of devices — a Mesh — and tell the compiler which axes split where via a PartitionSpec. Get the spec wrong and the compiler either errors out or, worse, silently replicates a parameter that should have been sharded — turning a 70B model into “70B per device”, which OOMs on every chip.

Whether you write Linen or NNX, PartitionSpec is the same JAX primitive — the sharding language is independent of the module library. NNX layers take the same Mesh and NamedSharding objects you’d use with raw JAX.

What is a PartitionSpec?

jax.sharding.PartitionSpec (conventionally imported as P) is a tuple-like spec with one entry per tensor axis. Each entry tells the compiler which mesh axis (if any) to shard that tensor axis along:

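import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, Mesh, NamedSharding
from jax.experimental import mesh_utils

# Needs 8 devices: a 2x4 grid with a "data" axis of size 2 and a "model" axis of size 4.
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=("data", "model"))

# An nnx.Linear's kernel has shape (in, out); stand-in kernel with in=8, out=16:
kernel = jnp.zeros((8, 16))
spec = P(None, "model")                 # axis 0 replicated, axis 1 split over "model"
sharding = NamedSharding(mesh, spec)
sharded_kernel = jax.device_put(kernel, sharding)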

Reading the spec aloud: “axis 0 is replicated everywhere, axis 1 is split across the ‘model’ mesh axis.” If the mesh’s model axis has 4 devices and out=16, each device holds an (in, 4) slice.
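To see which columns land on which device, the sketch below mimics that layout with NumPy in a single process. It assumes an (8, 16) stand-in kernel, cuts it into 4 column blocks (one per position on the "model" axis), and gives both rows of the "data" axis the same block. The dict keyed by (data, model) coordinates is an illustrative assumption, not how JAX stores shards internally.

```python
import numpy as np

kernel_in, kernel_out = 8, 16
data_size, model_size = 2, 4  # the 2x4 mesh from the example above

# Stand-in kernel; distinct values make it easy to check which slice is where.
kernel = np.arange(kernel_in * kernel_out).reshape(kernel_in, kernel_out)

# P(None, "model"): axis 1 is cut into model_size equal blocks, axis 0 is untouched.
column_blocks = np.split(kernel, model_size, axis=1)

# Device (d, m) holds block m. The "data" axis does not appear in the spec,
# so every data index holds an identical copy of the same block.
layout = {(d, m): column_blocks[m] for d in range(data_size) for m in range(model_size)}

print(layout[(0, 1)].shape)                            # (8, 4)
print(np.array_equal(layout[(0, 1)], layout[(1, 1)]))  # True
print(np.array_equal(layout[(0, 1)], kernel[:, 4:8]))  # True
```

Device m on the "model" axis owns output columns [4m, 4m + 4); concatenating the blocks along axis 1 rebuilds the full kernel.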

The three cases per axis

For each tensor axis, the spec entry is one of:

entry         meaning
None          axis is replicated: every device holds the full length of this axis
"name"        axis is sharded along the mesh axis named "name": its length is divided by that mesh axis's size
("a", "b")    axis is sharded along both mesh axes: its length is divided by the product of their sizes
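The three cases reduce to one divisor per tensor axis. Below is a minimal pure-Python sketch of that accounting, assuming even division; shard_shape and mesh_axis_sizes are hypothetical names for illustration, not part of the JAX API. It also mirrors JAX's behavior of treating missing trailing entries as None.

```python
from typing import Mapping, Sequence, Tuple, Union

# One spec entry: None, a mesh-axis name, or a tuple of mesh-axis names.
Entry = Union[None, str, Tuple[str, ...]]


def shard_shape(global_shape: Sequence[int],
                spec: Sequence[Entry],
                mesh_axis_sizes: Mapping[str, int]) -> Tuple[int, ...]:
    """Per-device shard shape (hypothetical helper), assuming even division."""
    if len(spec) > len(global_shape):
        raise ValueError("spec has more entries than the tensor has axes")
    # Missing trailing entries behave like None (replicated).
    padded = list(spec) + [None] * (len(global_shape) - len(spec))
    out = []
    for size, entry in zip(global_shape, padded):
        if entry is None:
            divisor = 1                          # replicated
        elif isinstance(entry, str):
            divisor = mesh_axis_sizes[entry]     # one mesh axis
        else:
            divisor = 1
            for name in entry:                   # several mesh axes: product
                divisor *= mesh_axis_sizes[name]
        if size % divisor:
            raise ValueError(f"axis of size {size} not divisible by {divisor}")
        out.append(size // divisor)
    return tuple(out)


mesh_axis_sizes = {"data": 2, "model": 4}
print(shard_shape((8, 16), (None, "model"), mesh_axis_sizes))            # (8, 4)
print(shard_shape((8, 16), (("data", "model"), None), mesh_axis_sizes))  # (1, 16)
```

The tuple entry divides axis 0 by 2 × 4 = 8, so each of the 8 devices in the mesh gets a single row.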

Standard layouts you’ll see

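  • Megatron column-parallel Linear: P(None, "model") on the kernel. The input is replicated across model-parallel devices; each device computes its own block of output columns.
  • Row-parallel Linear: P("model", None) on the kernel. Input features are sharded; each device produces a partial sum of the full output, so an all-reduce is needed to obtain the replicated result.
  • Data parallel: P("data", None) on a (batch, features) activation. Batch shards, features replicated. FSDP applies the "data" axis to parameters as well, so each weight is also split across data-parallel devices.
  • 2-D sharding: P("data", "model") on a (batch, hidden) activation. Batch on the data axis, hidden on the model axis; this combines data and tensor parallelism.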
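Why row-parallel needs the all-reduce while column-parallel does not can be checked numerically. This NumPy sketch simulates 4 model-parallel devices with Python lists instead of real devices and collectives; the shapes are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_model = 4
x = rng.standard_normal((2, 8))    # (batch, in)
w = rng.standard_normal((8, 16))   # (in, out)
reference = x @ w

# Column-parallel, kernel spec P(None, "model"): every device sees all of x
# and owns a block of output columns; concatenating the blocks gives x @ w.
col_outputs = [x @ w_blk for w_blk in np.split(w, n_model, axis=1)]
col_result = np.concatenate(col_outputs, axis=1)

# Row-parallel, kernel spec P("model", None): each device owns a slice of the
# input features and the matching block of kernel rows. Each local product
# already has the full output shape but is only a partial sum.
x_blocks = np.split(x, n_model, axis=1)
w_blocks = np.split(w, n_model, axis=0)
partials = [xb @ wb for xb, wb in zip(x_blocks, w_blocks)]
row_result = sum(partials)  # this sum is what the all-reduce computes

print(np.allclose(col_result, reference))   # True
print(np.allclose(partials[0], reference))  # False: one partial is not the answer
print(np.allclose(row_result, reference))   # True
```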

Single-device sandbox

A real shard is created by jax.device_put(x, NamedSharding(mesh, spec)) on a multi-device process. The sandbox has one device — there’s no (2, 4) mesh to be had.

Instead, this problem asks the accounting question that any engineer placing a PartitionSpec(None, "model") annotation has to compute by hand: what is the shape of one shard?

For an nnx.Linear(in, out) kernel of shape (kernel_in, kernel_out) with PartitionSpec(None, "model") over num_devices model-parallel devices:

  • Axis 0 (kernel_in) is None → replicated → per-shard size stays at kernel_in.
  • Axis 1 (kernel_out) is "model" → sharded across the mesh axis → per-shard size is kernel_out // num_devices.

Assume even division (num_devices divides kernel_out exactly).
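Plugging in realistic sizes shows why the annotation matters for memory. The sketch assumes a hypothetical (8192, 28672) MLP kernel stored in bfloat16 (2 bytes per element) and 8 model-parallel devices; these numbers are illustrative, not taken from a specific model.

```python
kernel_in, kernel_out = 8192, 28672
num_devices = 8          # size of the "model" mesh axis
bytes_per_elem = 2       # bfloat16

assert kernel_out % num_devices == 0, "even division assumed"

# PartitionSpec(None, "model"): axis 0 replicated, axis 1 divided.
shard = (kernel_in, kernel_out // num_devices)

full_mib = kernel_in * kernel_out * bytes_per_elem / 2**20
shard_mib = shard[0] * shard[1] * bytes_per_elem / 2**20

print(shard)       # (8192, 3584)
print(full_mib)    # 448.0
print(shard_mib)   # 56.0
```

A replicated kernel would cost the full 448 MiB on every device; the sharded one costs 56 MiB, an 8x saving from a single spec entry.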

Common pitfalls

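  • Off-by-one positionally: PartitionSpec entries match tensor axes left-to-right. A spec shorter than the tensor's rank is padded with None on the right, so the trailing axes are silently replicated; a spec longer than the rank is an error.
  • Confusing mesh-axis size with tensor-axis size: the divisor is the size of the mesh axis named "model" (mesh.shape["model"]), not the kernel's output dim.
  • Forgetting None for replicated axes: writing P("model") instead of P(None, "model") shifts "model" onto axis 0, so the kernel is split along kernel_in instead of kernel_out, and no error is raised.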
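The positional pitfalls are easy to reproduce with a few lines of shape arithmetic. per_shard is a hypothetical helper that mirrors JAX's rule of padding a short spec with None; it handles only single mesh-axis names, and the (1024, 4096) kernel with a 4-way "model" axis is an arbitrary example.

```python
def per_shard(global_shape, spec, mesh_axis_sizes):
    # Hypothetical helper: pad the spec with None on the right, then divide each
    # axis by the size of the mesh axis it maps to (1 if the entry is None).
    if len(spec) > len(global_shape):
        raise ValueError("spec longer than tensor rank")
    padded = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    return tuple(n // (1 if a is None else mesh_axis_sizes[a])
                 for n, a in zip(global_shape, padded))


mesh_axis_sizes = {"data": 2, "model": 4}
kernel_shape = (1024, 4096)  # (kernel_in, kernel_out)

print(per_shard(kernel_shape, (None, "model"), mesh_axis_sizes))  # (1024, 1024): intended
print(per_shard(kernel_shape, ("model",), mesh_axis_sizes))       # (256, 4096): None forgotten
```

Both results hold 1,048,576 elements per device, so memory usage looks identical; only the shape reveals that the wrong axis was split.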

Problem

Given (num_devices, kernel_in, kernel_out), compute and return [kernel_in_size, kernel_out_per_shard] as a 1-D (2,) float array, where:

  • kernel_in_size = kernel_in (replicated; same on every shard).
  • kernel_out_per_shard = kernel_out // num_devices (sharded).

Inputs:

  • num_devices: float (cast to int) — size of the "model" mesh axis.
  • kernel_in: float (cast to int) — input dimension of the kernel.
  • kernel_out: float (cast to int) — output dimension; divisible by num_devices.

Output: 1-D (2,) array — [kernel_in_size, kernel_out_per_shard].
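A minimal sketch of the input/output contract, using NumPy in place of jax.numpy so it runs without JAX. The function name shard_sizes and the float32 output dtype are assumptions; the problem statement fixes neither.

```python
import numpy as np


def shard_sizes(num_devices: float, kernel_in: float, kernel_out: float) -> np.ndarray:
    """Per-shard sizes for a kernel annotated PartitionSpec(None, "model")."""
    n, k_in, k_out = int(num_devices), int(kernel_in), int(kernel_out)
    if k_out % n != 0:
        raise ValueError("kernel_out must be divisible by num_devices")
    # Axis 0 is replicated (unchanged); axis 1 is split n ways.
    return np.array([k_in, k_out // n], dtype=np.float32)


print(shard_sizes(4.0, 512.0, 2048.0))  # [512. 512.]
```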
