
NNX Mesh Init Simulation

Why this matters

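A 70B-parameter model can't be initialized on a single device. At 2 bytes per parameter (bfloat16), the weights alone take about 140 GB. That is more than most single accelerators hold, even before counting gradients or optimizer state. The initial allocation must therefore happen directly into the sharded layout. Materializing the full kernel on one device and then sharding it doesn't work, because that first step OOMs.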
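Here is a quick back-of-the-envelope check of those numbers as a plain-Python sketch. The bytes per dtype are standard. The 8-device split is an illustrative assumption, and gradients, optimizer state, and activations are ignored.

```python
# Bytes per parameter for two common parameter dtypes.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}

def weight_gb(n_params: int, dtype: str, n_devices: int = 1) -> float:
    """Weight memory per device in GB (1 GB = 1e9 bytes), assuming an even split."""
    return n_params * BYTES_PER_PARAM[dtype] / n_devices / 1e9

n_params = 70_000_000_000
print(weight_gb(n_params, "float32"), weight_gb(n_params, "float32", n_devices=8))    # 280.0 35.0
print(weight_gb(n_params, "bfloat16"), weight_gb(n_params, "bfloat16", n_devices=8))  # 140.0 17.5
```

Only the sharded per-device figure has to fit on any one accelerator, which is why the allocation has to be born sharded.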

The production pattern is to init under a Mesh, with the correct out_shardings declared up front, so JAX places each parameter directly on the device that owns it:

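import jax
from flax import nnx
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

in_dim, out_dim, seed = 1024, 4096, 0
mesh_axes = jax.device_count()  # out_dim must be divisible by this

devices = mesh_utils.create_device_mesh((mesh_axes,))
mesh = Mesh(devices, axis_names=("model",))

def create_model(seed):
    return nnx.Linear(in_dim, out_dim, rngs=nnx.Rngs(seed))

# Trace the constructor abstractly: shapes and dtypes only, nothing is allocated.
graphdef, abstract_state = nnx.split(nnx.eval_shape(lambda: create_model(seed)))

# Declare the OUTPUT layout up front: shard each param's last (output) axis over
# "model", i.e. kernel (in_dim, out_dim) -> P(None, "model"), bias (out_dim,) -> P("model").
state_shardings = jax.tree.map(
    lambda x: NamedSharding(mesh, P(*([None] * (len(x.shape) - 1)), "model")),
    abstract_state,
)

with mesh:
    # jit builds the model and returns only its State; out_shardings places
    # each array directly in its sharded layout.
    sharded_state = jax.jit(
        lambda: nnx.split(create_model(seed))[1],
        out_shardings=state_shardings,
    )()
model = nnx.merge(graphdef, sharded_state)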

The kernel never exists in full anywhere. Each device only ever holds its slice.

Why we simulate here

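A real 1-D mesh of size mesh_axes (for example Mesh(mesh_utils.create_device_mesh((mesh_axes,)), ("model",))) requires mesh_axes devices, and the sandbox has only one. So we simulate the control flow instead. We build the model that would normally be built under with mesh:, then report the configuration (features, mesh_axes) along with one structural check on the result (n_leaves).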
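You can rehearse real sharding without extra hardware. Outside this sandbox, one common option is to have XLA expose the host CPU as several virtual devices. This is a sketch that assumes a CPU-only JAX install. The flag must be set before JAX initializes its backends, and the count of 8 is arbitrary.

```python
import os

# Must run before JAX initializes (i.e. before `import jax` in this process).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import numpy as np
from jax.sharding import Mesh

cpu_devices = jax.devices("cpu")  # now 8 virtual CPU devices
mesh = Mesh(np.array(cpu_devices[:8]), axis_names=("model",))
print(len(cpu_devices), dict(mesh.shape))
```

With this setup, the production snippet above can run with mesh_axes = 8 on a laptop. The virtual devices share one host's memory, though, so this setup only exercises the sharding logic, not the capacity.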

The educational point isn’t the mesh object itself — it’s that an nnx Module constructed under with mesh: looks identical to one built without; the difference is where its arrays live. So our in-sandbox nnx.Linear(...) IS the model the production code would build. We just don’t actually shard.

Anatomy of a mesh

mesh = Mesh(devices, axis_names=("data", "model"))
  • devices: an np.ndarray of Device objects with shape (data_axis_size, model_axis_size). The product of the axis sizes equals the number of devices in the mesh.
  • axis_names: a tuple of names — one per axis of the device array. These are the strings that go into PartitionSpec.
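The following is a minimal sketch of the two fields that assumes nothing about the hardware. It takes whatever devices jax.devices() returns and lays them out as a (1, n) grid, so the "data" axis has size 1 and every device sits on "model".

```python
import jax
import numpy as np
from jax.sharding import Mesh

devices = np.array(jax.devices())  # whatever devices this machine exposes
n = devices.size

# The product of the axis sizes (1 * n) must equal the number of devices passed in.
mesh = Mesh(devices.reshape(1, n), axis_names=("data", "model"))

print(mesh.devices.shape)  # -> (1, n)
print(mesh.axis_names)     # -> ('data', 'model')
print(dict(mesh.shape))    # -> {'data': 1, 'model': n}
```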

A 1-D mesh Mesh(devices, ("model",)) is the simplest case: everything is sharded along one model-parallel axis. A 2-D mesh ("data", "model") lets you combine FSDP (parameters sharded along "data") with tensor parallelism (sharded along "model") at the same time.
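The hypothetical helper local_shape below shows what each axis buys you by computing the per-device block of a kernel for a given spec. It is a simplified stand-in for NamedSharding.shard_shape. Each spec entry is either None or a single axis name, and the spec must have one entry per array dimension. Real PartitionSpecs are more flexible. The 8-device and 2x4 meshes are illustrative.

```python
def local_shape(global_shape, spec, mesh_shape):
    """Hypothetical helper: per-device block shape of an evenly sharded array.

    spec has one entry per array dim: None (replicated) or one mesh axis name.
    mesh_shape maps axis name -> axis size.
    """
    if len(spec) != len(global_shape):
        raise ValueError("spec needs one entry per array dimension")
    block = []
    for dim, axis in zip(global_shape, spec):
        size = 1 if axis is None else mesh_shape[axis]
        if dim % size:
            raise ValueError(f"dim {dim} is not divisible by axis {axis!r} (size {size})")
        block.append(dim // size)
    return tuple(block)

kernel = (8192, 8192)
# 1-D mesh, pure tensor parallelism: columns split 8 ways.
print(local_shape(kernel, (None, "model"), {"model": 8}))               # (8192, 1024)
# 2-D mesh: rows split over "data" (FSDP-style), columns over "model" (TP).
print(local_shape(kernel, ("data", "model"), {"data": 2, "model": 4}))  # (4096, 2048)
```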

Why the with mesh: block matters

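Inside the block, the mesh becomes the context mesh. NamedSharding always takes its mesh explicitly. APIs that accept a bare PartitionSpec, however, resolve the spec's axis names against the context mesh. One example is jax.lax.with_sharding_constraint(x, P(None, "model")) inside a jitted function. The nnx.Rngs and nnx.Linear constructors never look at the mesh. The arrays they allocate simply take whatever layout the surrounding jit declares, through out_shardings or sharding constraints.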
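Here is a small sketch of the context-mesh behavior, assuming only jax and numpy and a JAX version where Mesh still works as a context manager. The bare PartitionSpec inside the jitted function carries no mesh of its own, so it is resolved against the enclosing with mesh: block. The shapes are arbitrary. The column count is scaled by the device count so it always divides evenly.

```python
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

@jax.jit
def init_kernel():
    w = jax.random.normal(jax.random.key(0), (8, 16 * n))
    # A bare PartitionSpec has no mesh attached; "model" is looked up
    # in the mesh of the enclosing `with mesh:` block at trace time.
    return jax.lax.with_sharding_constraint(w, P(None, "model"))

with mesh:
    w = init_kernel()

# Each device holds an (8, 16) column block of the (8, 16 * n) kernel.
print([shard.data.shape for shard in w.addressable_shards])
```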

Common pitfalls

  • Initializing without a mesh and then sharding: the unsharded tensor briefly exists in full on one device. For huge models this OOMs immediately.
  • Building the mesh inside the function body: Mesh(...) is an ordinary Python object describing the device layout. Build it once at program start and reuse it.
  • Forgetting out_shardings: being inside with mesh: is not enough on its own. jit also needs to know how its outputs should be laid out. Without out_shardings (or a sharding constraint), nothing tells it to split the parameters.
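All three pitfalls can be avoided together, as this sketch shows using a plain JAX array rather than an nnx Module. The mesh is built once up front, the array is created inside jit, and the layout is declared through out_shardings and then checked. The shapes are arbitrary.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))  # built once, up front
kernel_sharding = NamedSharding(mesh, P(None, "model"))

# out_shardings tells jit where each output block lives, so every device
# materializes only its own columns.
init_kernel = jax.jit(
    lambda: jax.random.normal(jax.random.key(0), (8, 16 * n)),
    out_shardings=kernel_sharding,
)
w = init_kernel()

# Guard against the pitfall: fail loudly if the layout isn't what we declared.
assert w.sharding.is_equivalent_to(kernel_sharding, w.ndim)
```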

Problem

Implement mesh_init_simulate(seed, x, features, mesh_axes):

  1. Cast seed, features, mesh_axes to ints.
  2. Build model = nnx.Linear(x.shape[-1], features, rngs=nnx.Rngs(seed)) — this stands in for the production with mesh: model = ... call (the surrounding mesh is what differs, not the build line).
  3. Confirm the model has the expected number of state leaves with _, state = nnx.split(model); n_leaves = len(jax.tree_util.tree_leaves(state)).
  4. Return jnp.array([float(features), float(mesh_axes), float(n_leaves)]) as a 1-D (3,) array.

For an nnx.Linear with the default use_bias=True, n_leaves is 2 (kernel and bias).

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, in_dim).
  • features: float (cast to int) — Linear out dim.
  • mesh_axes: float (cast to int) — size of the simulated 1-D mesh.

Output: 1-D array of shape (3,): [features, mesh_axes, n_leaves].

Hints

flax nnx sharding mesh init
