
NNX State Sharding

Why this matters

NNX models live as Python objects, but to shard them across devices you need their state as a pytree: a nested container whose leaves are arrays, like any other JAX pytree. The sharding layout is itself a pytree with exactly the same structure as the state, except that every leaf is a PartitionSpec instead of an array. JAX uses these paired pytrees to place the state across the mesh.
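To see the paired-pytree idea without Flax in the way, here is a minimal sketch that uses a plain dict as a stand-in for model state (the names params and specs are illustrative only). Each PartitionSpec counts as a single leaf, so both trees flatten to the same structure:

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# Stand-in for a model's state: a dict pytree whose leaves are arrays.
params = {"kernel": jnp.zeros((4, 8)), "bias": jnp.zeros((8,))}

# The matching layout: same keys, but every leaf is a PartitionSpec.
specs = {"kernel": P(None, "model"), "bias": P("model")}

# JAX does not descend into a PartitionSpec, so both trees have
# the same structure (two leaves under the same keys).
same_structure = (
    jax.tree_util.tree_structure(params) == jax.tree_util.tree_structure(specs)
)

# Walk both trees in lockstep: pair each array's shape with its spec.
paired = jax.tree_util.tree_map(lambda x, spec: (x.shape, spec), params, specs)
```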

The pattern is universal:

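import jax
from flax import nnx
from jax.sharding import NamedSharding

# model and mesh are assumed to exist; pick_spec is defined under "Picking specs by leaf".
graphdef, state = nnx.split(model)                        # static graph + array pytree
sharding_tree = jax.tree_util.tree_map(pick_spec, state)  # same structure, PartitionSpec leaves
sharded_state = jax.tree_util.tree_map(
    lambda x, spec: jax.device_put(x, NamedSharding(mesh, spec)),  # place each array on the mesh
    state, sharding_tree,
)
sharded_model = nnx.merge(graphdef, sharded_state)        # rebuild a callable Module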

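Two pytrees of the same structure, operated on in lockstep: this is the heart of how jax.jit(in_shardings=..., out_shardings=...) and friends work, since those arguments likewise take a pytree of shardings that mirrors the inputs or outputs.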

Why split first

NNX modules have Python-level structure (the Module graph) and array state (the Variable values). nnx.split(model) peels them apart:

  • graphdef — the bare structure: nested Linears, Convs, attribute names, no array data. Treated as a static input by JIT.
  • state — a pytree whose leaves are the actual arrays (kernel.value, bias.value, etc.).

JAX cares about pytrees of arrays. Sharding can only be applied to the array side, so we work on state. After sharding, we glue the model back with nnx.merge(graphdef, sharded_state).

Picking specs by leaf

A simple nnx.Linear(in_dim, out_dim) has two leaves:

  • kernel: shape (in_dim, out_dim) — choose P(None, "model") to shard the output dim.
  • bias: shape (out_dim,) — choose P("model") to shard along the same mesh axis (otherwise the bias and kernel disagree on where the per-device slice lives).

The pick_spec function inspects each leaf’s rank to decide:

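from jax.sharding import PartitionSpec as P

def pick_spec(leaf):
    if leaf.ndim == 2:
        return P(None, "model")    # kernel (in_dim, out_dim): shard the output dim
    return P("model")              # bias (out_dim,): same mesh axis as the kernel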

For larger models (Conv, MultiHeadAttention), you’d write more branches — or use nnx.with_partitioning at module-construction time to attach the spec to each Variable directly.

Why uniform structure matters

Because state and sharding_tree share the same pytree shape, any tree_map operation works the same way on both. You can:

  • Count leaves (len(jax.tree_util.tree_leaves(state))).
  • Inspect ranks, dtypes, sizes uniformly.
  • Pair them in a single tree_map call.

For an nnx.Linear, that count is 2: one entry per Variable (kernel + bias). Add a LayerNorm to the model and the count becomes 4 (the Linear's kernel and bias plus the LayerNorm's scale and bias).

Common pitfalls

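  • Sharding without splitting: a Module is not a plain pytree of arrays, so jax.tree_util.tree_map over it won't reliably hand you one array leaf per Variable (depending on the Flax version, the whole Module may be treated as a single opaque leaf). Always nnx.split first.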
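  • Mismatched specs: sharding the kernel on "model" but replicating the bias won't change the numbers (jit results don't depend on sharding), but the two layouts no longer agree, so XLA may have to insert resharding communication, and the bias is stored redundantly on every device. Pair them.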
  • Forgetting to merge back: a sharded state alone isn’t a callable model — nnx.merge(graphdef, sharded_state) rebuilds the Module with sharded variables.

Problem

Implement state_to_sharded_layout(seed, x, features):

  1. Build model = nnx.Linear(x.shape[-1], int(features), rngs=nnx.Rngs(int(seed))).
  2. Split with _, state = nnx.split(model).
  3. Build a parallel sharding pytree where each leaf is a PartitionSpec: P(None, "model") for the 2-D kernel, P("model") for the 1-D bias. Use jax.tree_util.tree_map.
  4. Count the leaves of the sharding tree: n_leaves = len(jax.tree_util.tree_leaves(sharding_tree)).
  5. Return jnp.array([float(n_leaves)]) as a 1-D array of shape (1,).

For nnx.Linear, this should always be 2.0 — kernel + bias.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, in_dim) (only its last dimension is used).
  • features: float (cast to int) — output dim of Linear.

Output: 1-D array of shape (1,) holding [num_leaves_in_sharding_tree].

Hints

flax nnx sharding pytree state
