NNX State Sharding
Why this matters
NNX models live as Python objects, but to shard them across devices
you need their state as a pytree of arrays, like any regular JAX
pytree. The trick is that the sharding layout itself is
also a pytree, with the same structure as the state, but
every leaf is a PartitionSpec instead of an array. JAX uses these
paired pytrees to place the state across the mesh.
The pattern is universal:
    import jax
    from flax import nnx
    from jax.sharding import NamedSharding

    # model, mesh, and pick_spec are defined elsewhere.
    graphdef, state = nnx.split(model)
    sharding_tree = jax.tree_util.tree_map(pick_spec, state)
    sharded_state = jax.tree_util.tree_map(
        lambda x, spec: jax.device_put(x, NamedSharding(mesh, spec)),
        state, sharding_tree,
    )
    sharded_model = nnx.merge(graphdef, sharded_state)
Two pytrees of the same structure operated on in lockstep — this is
the heart of how jax.jit(out_shardings=...) and friends work.
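The pattern can be tried without a model at all. The following is a minimal sketch, assuming a plain dict of arrays as a stand-in for the NNX state and a one-axis mesh built from whatever devices are visible. The names state, in_dim, out_dim, and the layer sizes are illustrative and not taken from a real model.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-axis mesh named "model" over every visible device (1 on a plain CPU).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("model",))

# Stand-in for the state pytree. out_dim is a multiple of the device count
# so the sharded dimension divides evenly.
in_dim, out_dim = 3, 4 * devices.size
state = {"kernel": jnp.ones((in_dim, out_dim)), "bias": jnp.zeros((out_dim,))}

def pick_spec(leaf):
    # 2-D leaves are kernels; everything else is treated as a 1-D bias.
    return P(None, "model") if leaf.ndim == 2 else P("model")

# Same structure as state, but every leaf is a PartitionSpec.
sharding_tree = jax.tree_util.tree_map(pick_spec, state)

# Walk both trees in lockstep and place each array on the mesh.
sharded_state = jax.tree_util.tree_map(
    lambda x, spec: jax.device_put(x, NamedSharding(mesh, spec)),
    state, sharding_tree,
)

print(sharded_state["kernel"].sharding)
```

With a real model, state would come from nnx.split and the result would go back through nnx.merge; the two tree_map calls stay the same.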
Why split first
NNX modules have Python-level structure (the Module graph) and
array state (the Variable values). nnx.split(model) peels them
apart:
- graphdef: the bare structure (nested Linears, Convs, attribute names) with no array data. Treated as a static input by JIT.
- state: a pytree whose leaves are the actual arrays (kernel.value, bias.value, etc.).
JAX cares about pytrees of arrays. Sharding can only be applied to
the array side, so we work on state. After sharding, we glue the
model back with nnx.merge(graphdef, sharded_state).
Picking specs by leaf
A simple nnx.Linear(in_dim, out_dim) has two leaves:
- kernel: shape (in_dim, out_dim). Choose P(None, "model") to shard the output dim.
- bias: shape (out_dim,). Choose P("model") to shard along the same mesh axis (otherwise the bias and kernel disagree on where the per-device slice lives).
The pick_spec function inspects each leaf’s rank to decide:
    from jax.sharding import PartitionSpec as P

    def pick_spec(leaf):
        if leaf.ndim == 2:
            return P(None, "model")  # kernel
        return P("model")  # bias (1-D)
For larger models (Conv, MultiHeadAttention), you’d write more
branches — or use nnx.with_partitioning at module-construction
time to attach the spec to each Variable directly.
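As a rough illustration of the extra branches, here is one way a rank-based pick_spec might grow to cover a convolution. It is a sketch under assumptions: the Conv kernel is taken to be laid out as (height, width, in_channels, out_channels), the rules for each rank are illustrative choices, and the nested dict is a hypothetical stand-in for a real state.

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def pick_spec(leaf):
    """Choose a PartitionSpec from the leaf's rank alone (illustrative rules)."""
    if leaf.ndim == 4:
        # Conv kernel assumed as (height, width, in_ch, out_ch): shard out_ch.
        return P(None, None, None, "model")
    if leaf.ndim == 2:
        # Dense kernel (in_dim, out_dim): shard out_dim.
        return P(None, "model")
    if leaf.ndim == 1:
        # Bias or scale vector: shard along the same mesh axis.
        return P("model")
    # Scalars and anything unexpected stay replicated.
    return P()

# Hypothetical state: a dict standing in for a Conv followed by a Linear.
state = {
    "conv": {"kernel": jnp.zeros((3, 3, 8, 16)), "bias": jnp.zeros((16,))},
    "linear": {"kernel": jnp.zeros((16, 4)), "bias": jnp.zeros((4,))},
}
sharding_tree = jax.tree_util.tree_map(pick_spec, state)
print(sharding_tree["conv"]["kernel"])
```

Rank alone cannot tell two different 2-D parameters apart, which is one reason to attach specs at construction time instead.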
Why uniform structure matters
Because state and sharding_tree share the same pytree shape,
any tree_map operation works the same way on both. You can:
- Count leaves (len(jax.tree_util.tree_leaves(state))).
- Inspect ranks, dtypes, sizes uniformly.
- Pair them in a single tree_map call.
For an nnx.Linear, that count is 2: one entry per Variable
(kernel + bias). Add a LayerNorm to the model and the count
becomes 4 (kernel, bias, scale, bias).
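These three operations can be sketched on a small stand-in. The example assumes a plain dict in place of the NNX state and passes an explicit is_leaf predicate (the helper is_spec is an illustrative name), so that each PartitionSpec counts as one leaf regardless of how a given JAX version represents it internally.

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

state = {"kernel": jnp.ones((3, 4)), "bias": jnp.zeros((4,))}
sharding_tree = {"kernel": P(None, "model"), "bias": P("model")}

def is_spec(node):
    # Treat each PartitionSpec as a single leaf.
    return isinstance(node, P)

state_leaves, state_def = jax.tree_util.tree_flatten(state)
spec_leaves, spec_def = jax.tree_util.tree_flatten(sharding_tree, is_leaf=is_spec)

# Count leaves on both sides and compare the tree structures.
n_state, n_spec = len(state_leaves), len(spec_leaves)
same_structure = state_def == spec_def

# Inspect ranks uniformly, then pair each array's shape with its spec.
ranks = jax.tree_util.tree_map(lambda x: x.ndim, state)
pairs = jax.tree_util.tree_map(lambda x, s: (x.shape, s), state, sharding_tree)
print(n_state, n_spec, same_structure)
```

If the two structures ever differ, the paired tree_map call raises an error instead of silently misaligning arrays and specs.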
Common pitfalls
- Sharding without splitting: tree_map on a Module silently walks Python attributes that aren't part of the pytree. Always nnx.split first.
- Mismatched specs: a kernel sharded on "model" with a replicated bias does not change what the matmul computes under jax.jit, but it keeps a full copy of the bias on every device and can force extra resharding. Pair them.
- Forgetting to merge back: a sharded state alone isn't a callable model; nnx.merge(graphdef, sharded_state) rebuilds the Module with sharded variables.
Problem
Implement state_to_sharded_layout(seed, x, features):
- Build model = nnx.Linear(x.shape[-1], int(features), rngs=nnx.Rngs(int(seed))).
- Split with _, state = nnx.split(model).
- Build a parallel sharding pytree where each leaf is a PartitionSpec: P(None, "model") for the 2-D kernel, P("model") for the 1-D bias. Use jax.tree_util.tree_map.
- Count the leaves of the sharding tree: n_leaves = len(jax.tree_util.tree_leaves(sharding_tree)).
- Return jnp.array([float(n_leaves)]) as a 1-D (1,) array.
For nnx.Linear, this should always be 2.0 — kernel + bias.
Inputs:
- seed: float (cast to int).
- x: 2-D (N, in_dim) (only its last dim is used).
- features: float (cast to int), the output dim of Linear.
Output: 1-D (1,) — [num_leaves_in_sharding_tree].