NNX Mesh Init Simulation
Why this matters
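A 70B-parameter model can’t, in practice, be initialized on a single device: in fp32 the weights alone are about 280 GB (140 GB in bf16), beyond the HBM of a typical accelerator. Even when the params fit in the combined HBM of the mesh, the initial allocation must happen directly into the sharded layout. There’s no point materializing the full kernel on one device and then sharding it, because step one OOMs.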
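A back-of-the-envelope check of those figures, as a minimal sketch. per_device_gb is a hypothetical helper that assumes a perfectly even split and counts only the weights (no optimizer state, no activations):

```python
def per_device_gb(n_params: int, bytes_per_param: int, n_devices: int) -> float:
    """Weight memory per device in GB (1e9 bytes), assuming the params are
    split evenly over n_devices. Optimizer state and activations are ignored."""
    return n_params * bytes_per_param / n_devices / 1e9

N_PARAMS = 70_000_000_000  # the 70B model from the text
BF16 = 2                   # bytes per parameter in bfloat16

unsharded = per_device_gb(N_PARAMS, BF16, 1)  # whole model on one device
sharded_8 = per_device_gb(N_PARAMS, BF16, 8)  # 1-D mesh of 8 devices
print(f"1 device: {unsharded:.1f} GB, 8 devices: {sharded_8:.1f} GB each")
```

Splitting over 8 devices turns one 140 GB allocation into eight 17.5 GB ones, but only if each device allocates its own slice from the start.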
The production pattern is to init under a Mesh, with the
correct out_shardings declared up front, so JAX places each
parameter directly on the device that owns it:
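```python
import jax
from flax import nnx
from jax.experimental import mesh_utils
from jax.sharding import Mesh

in_dim, out_dim = 1024, 4096    # out_dim must divide evenly by mesh_axes
mesh_axes = jax.device_count()  # the mesh shape must use every visible device
devices = mesh_utils.create_device_mesh((mesh_axes,))
mesh = Mesh(devices, axis_names=("model",))

def create_model(seed):
    # Annotate each param with the mesh axis its dims are split along.
    return nnx.Linear(
        in_dim, out_dim,
        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "model")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ("model",)),
        rngs=nnx.Rngs(seed),
    )

with mesh:
    # Specify how the OUTPUT model's params should be sharded: trace init
    # abstractly (nothing is allocated), then turn the annotations into specs.
    graphdef, abstract_state = nnx.split(nnx.eval_shape(lambda: create_model(0)))
    state_spec = nnx.get_named_sharding(abstract_state, mesh)
    sharded_state = jax.jit(
        lambda s: nnx.state(create_model(s)),
        static_argnums=0,
        out_shardings=state_spec,
    )(0)
    model = nnx.merge(graphdef, sharded_state)
```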
The kernel never exists in full anywhere. Each device only ever holds its slice.
Why we simulate here
A real 1-D mesh of size mesh_axes (mesh_utils.create_device_mesh((mesh_axes,)))
requires mesh_axes physical devices. The sandbox has only one. So we simulate
the control flow: build the model that would normally be built under with mesh:,
and report the configuration knobs (features, mesh_axes) together with the
number of state leaves (n_leaves).
The educational point isn’t the mesh object itself. It’s that an
nnx Module constructed under with mesh: has the same structure as one
built without it (same graphdef, same parameter shapes and dtypes); the
difference is where its arrays live. So our in-sandbox nnx.Linear(...) IS
the model the production code would build. We just don’t actually shard.
Anatomy of a mesh
```python
mesh = Mesh(devices, axis_names=("data", "model"))
```
- devices: an np.ndarray of Device objects, with shape (data_axis_size, model_axis_size). The product of the two sizes is the number of devices in the mesh (usually every available device).
- axis_names: a tuple of names, one per axis of the device array. These are the strings that go into PartitionSpec.
A 1-D mesh Mesh(devices, ("model",)) is the simplest case:
everything is sharded along one model-parallel axis. A 2-D mesh
("data", "model") lets you do FSDP+TP simultaneously.
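The anatomy above can be sketched with whatever devices are visible. This is an illustration, not a recommended layout: data_size, model_size and kernel_shape are made-up names, and the device array is built with a plain np.array(...).reshape(...) so its shape is explicit. Sharding.shard_shape reports the block each device would hold:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Use every visible device. On 8 devices this gives a (2, 4) mesh; on a
# one-device sandbox it degenerates to (1, 1), which still runs.
n = jax.device_count()
data_size = 2 if n % 2 == 0 else 1
model_size = n // data_size

devices = np.array(jax.devices()).reshape(data_size, model_size)
mesh = Mesh(devices, axis_names=("data", "model"))

# A kernel whose rows are split over "data" (FSDP-style) and whose columns
# are split over "model" (tensor parallelism).
kernel_shape = (8 * data_size, 8 * model_size)
sharding = NamedSharding(mesh, P("data", "model"))

print(mesh.shape)                          # axis name -> axis size
print(sharding.shard_shape(kernel_shape))  # per-device block of the kernel
```

Whatever the device count, each device ends up with an (8, 8) block, because both kernel dims were chosen as multiples of the matching mesh axis.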
Why the with mesh: block matters
Inside the block, the mesh becomes the ambient context that JAX
resolves any bare PartitionSpec against, for example in
jax.lax.with_sharding_constraint(x, P(...)). (NamedSharding itself
always takes its mesh as an explicit argument.) The nnx.Rngs
constructor and the Linear constructor are plain Python and know
nothing about the mesh, but the arrays they allocate inherit the
layout that the surrounding jit + out_shardings declared.
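A minimal JAX-only sketch of that ambient-mesh behavior, assuming a JAX version in which Mesh still works as a context manager (init_kernel is a hypothetical name). The function is traced on its first call, which happens inside the with block, so the bare PartitionSpec finds the mesh there:

```python
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

@jax.jit
def init_kernel(key):
    # The random op itself knows nothing about the mesh.
    w = jax.random.normal(key, (16, 8 * n))
    # A bare PartitionSpec has no mesh of its own; it is resolved against
    # the mesh made active by the surrounding `with mesh:` block.
    return jax.lax.with_sharding_constraint(w, P(None, "model"))

with mesh:
    w = init_kernel(jax.random.key(0))

# Each device holds a (16, 8) column block of the (16, 8 * n) kernel.
print([shard.data.shape for shard in w.addressable_shards])
```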
Common pitfalls
- Initializing without a mesh and then sharding: the unsharded tensor briefly exists in full on one device. For huge models this OOMs immediately.
- Building the mesh inside the function body: Mesh(...) is a Python value; build it once at program start.
- Forgetting out_shardings: just being inside with mesh: isn’t enough. jit also needs to know how the outputs should be laid out, via out_shardings (or a sharding constraint inside the function).
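The first pitfall and its fix, side by side, in a JAX-only sketch; the shapes are arbitrary stand-ins for a real kernel. The fix lets XLA partition the init computation. Whether the random bits themselves are generated per shard also depends on the jax_threefry_partitionable config option, which newer JAX versions enable by default:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
sharding = NamedSharding(mesh, P(None, "model"))
shape = (16, 8 * n)
key = jax.random.key(0)

# Pitfall: the full (16, 8 * n) array is first created on one device and
# only then split up. For a 70B model, the first line is where it OOMs.
w_bad = jax.random.normal(key, shape)
w_bad = jax.device_put(w_bad, sharding)

# Fix: tell jit where the output lives, so the compiled init can write
# each device's slice directly into the sharded layout.
init = jax.jit(lambda k: jax.random.normal(k, shape), out_shardings=sharding)
w_good = init(key)

print(w_good.sharding.shard_shape(shape))
```

Both arrays end up with the same layout; the difference is only in what had to exist along the way.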
Problem
Implement mesh_init_simulate(seed, x, features, mesh_axes):
- Cast seed, features, mesh_axes to ints.
- Build model = nnx.Linear(x.shape[-1], features, rngs=nnx.Rngs(seed)). This stands in for the production with mesh: model = ... call (the surrounding mesh is what differs, not the build line).
- Confirm the model has the expected number of state leaves with _, state = nnx.split(model); n_leaves = len(tree_leaves(state)).
- Return jnp.array([float(features), float(mesh_axes), float(n_leaves)]) as a 1-D (3,) array.
For an nnx.Linear, n_leaves is 2 (kernel + bias).
Inputs:
- seed: float (cast to int).
- x: 2-D, shape (N, in_dim).
- features: float (cast to int), the Linear output dimension.
- mesh_axes: float (cast to int), the size of the simulated 1-D mesh.
Output: 1-D (3,) — [features, mesh_axes, n_leaves].