
nn.with_partitioning Annotation

Why this matters

with_sharding_constraint annotates activations — values inside a function body. But what about parameters? A 70B-parameter LLM has each weight matrix sharded a specific way (column-parallel, FSDP’d, replicated…). You need to attach that sharding info at the place the parameter is defined, not at every site that touches it.

flax.linen.with_partitioning is that attachment point. It wraps a parameter’s initializer so that the initialized value comes back inside a Partitioned box carrying the axis names. Downstream tooling (jit and pjit via helpers such as nn.get_partition_spec, plus Orbax and nnx.bridge) uses that metadata to decide how to lay out the weight across devices.

The API

import flax.linen as nn

layer = nn.Dense(
    features=128,
    kernel_init=nn.with_partitioning(
        nn.initializers.lecun_normal(),
        (None, "model"),       # PartitionSpec-like tuple
    ),
    bias_init=nn.with_partitioning(
        nn.initializers.zeros_init(),
        ("model",),
    ),
)

with_partitioning(init_fn, names) returns a new initializer that:

  1. Calls the wrapped init to produce the actual array.
  2. Wraps the result in a flax.linen.Partitioned box that records the names tuple alongside the array.

The names tuple has one entry per param axis, same convention as PartitionSpec: None for replicated, "name" for sharded along that mesh axis.

Reading the metadata back

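After params = layer.init(rng, x), the Dense kernel in params is a Partitioned pytree node, not a bare array. Inside the module the box is unwrapped automatically, so layer.apply(params, x) works unchanged; outside the module, read .value or unbox the whole tree with flax.core.meta.unbox(params). To pull the spec out: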

from flax.linen import Partitioned

# params comes from layer.init(rng, x) above.
leaf = params["params"]["kernel"]
if isinstance(leaf, Partitioned):
    names = leaf.names         # (None, "model")
    value = leaf.value         # the actual array

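nn.get_partition_spec(params) walks the pytree and produces a parallel pytree of PartitionSpec objects; leaves without metadata get an empty PartitionSpec(), meaning replicated. Combine those specs with a mesh (for example with jax.sharding.NamedSharding, or directly via nn.get_sharding(params, mesh)) before passing them to jax.device_put or using them as jit/pjit shardings.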

Why annotate at definition time?

Two reasons:

  1. Locality. The author of the layer knows how its weights should be sharded. Embedding that knowledge in the layer means callers don’t need to plumb a separate sharding tree to every init.
  2. Initializer participation. Variance-scaling inits (e.g., Glorot or LeCun) need the full fan-in/fan-out, not the per-shard fan. The wrapper calls the original init with the full logical shape and only attaches metadata to the result, so the statistics are computed for the whole weight and any sharding is applied afterwards.

Single-device behavior

On a single device with no mesh:

  • nn.with_partitioning(init_fn, names) still returns a valid initializer.
  • init runs and returns a Partitioned-wrapped array of the same shape and dtype as the unwrapped init would have produced.
  • No actual sharding happens. The value is on the one device, and the metadata tags travel along for free.

This is the operating mode for our sandbox: the annotation is inert, but the API surface is identical to the multi-device case.

Common pitfalls

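  • Wrong tuple length: (None, "model") for a 2-D kernel is fine. A tuple longer than the parameter’s rank, such as (None, None, "model"), fails with a rank-mismatch error when the metadata is later consumed as a sharding. A shorter tuple such as ("model",) is accepted by JAX, which treats the missing trailing axes as unsharded, so it silently shards axis 0 instead of the output axis.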
  • Forgetting bias: in column-parallel Dense the bias is also sharded along the output axis. If you partition the kernel as (None, "model"), partition the bias as ("model",).
  • Mixing with bare initializers: kernel_init = lecun_normal() (no wrap) can coexist with a wrapped bias_init; the kernel just won’t have sharding metadata. nn.get_partition_spec then reports that param as replicated, which is usually NOT what you want for big weights.
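
How JAX itself treats spec length can be probed on a single device. The sketch below uses a one-device mesh whose only axis is named "model"; the mesh layout and the (4, 8) array are assumptions made for the example. It places a 2-D array with a spec shorter than its rank, then tries one longer than its rank.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-device mesh with a single axis "model" (size 1): enough to validate specs.
mesh = Mesh(np.array(jax.devices()[:1]), ("model",))
kernel = jnp.ones((4, 8))  # illustrative 2-D kernel

# Spec shorter than the rank: accepted; only axis 0 is mapped to "model".
short = jax.device_put(kernel, NamedSharding(mesh, PartitionSpec("model")))

# Spec longer than the rank: rejected when the sharding is applied.
try:
    jax.device_put(kernel, NamedSharding(mesh, PartitionSpec(None, None, "model")))
    too_long_raised = False
except ValueError:
    too_long_raised = True

print(short.shape, too_long_raised)
```

On a real multi-device mesh, the short spec would split the input axis rather than the output axis, which is why it is worth checking the names tuple against the kernel layout.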

Problem

Build an nn.Dense(features) whose kernel is annotated with nn.with_partitioning(nn.initializers.lecun_normal(), (None, "model")), initialize it on a (seed, x) pair, and return the kernel’s logical shape as a 1-D (2,) array [in_dim, features]. The partitioning metadata doesn’t change the shape; it just rides alongside the array, which you read through the box’s .value.

For a (N, in_dim) input, the kernel is (in_dim, features), so:

  • in_dim = x.shape[-1]
  • out_dim = features

Cast both to float and stack into a 1-D (2,) array.

Inputs:

  • seed: float (cast to int) — PRNG seed.
  • x: 2-D (N, in_dim) input.
  • features: float (cast to int) — Dense output dimension.

Output: 1-D (2,) array [in_dim, features].

Hints

flax sharding partitioning
