nn.with_partitioning Annotation
Why this matters
with_sharding_constraint annotates activations — values inside a
function body. But what about parameters? A 70B-parameter LLM has
each weight matrix sharded a specific way (column-parallel, FSDP’d,
replicated…). You need to attach that sharding info at the place
the parameter is defined, not at every site that touches it.
flax.linen.with_partitioning is that attachment point. It wraps a
parameter’s initializer, and the resulting initialized value
carries Partitioned metadata that downstream tooling — jit,
pjit, Orbax, nnx.bridge — reads to figure out how to lay out the
weight across devices.
The API
import flax.linen as nn

layer = nn.Dense(
    features=128,
    kernel_init=nn.with_partitioning(
        nn.initializers.lecun_normal(),
        (None, "model"),  # PartitionSpec-like tuple
    ),
    bias_init=nn.with_partitioning(
        nn.initializers.zeros_init(),
        ("model",),
    ),
)
with_partitioning(init_fn, names) returns a new initializer that:
- Calls the wrapped init to produce the actual array.
- Wraps the result in a flax.linen.Partitioned box that records the names tuple alongside the array.
The names tuple has one entry per param axis, same convention as
PartitionSpec: None for replicated, "name" for sharded along
that mesh axis.
Reading the metadata back
After params = layer.init(rng, x), the Dense kernel is a
Partitioned pytree node rather than a bare array. Inside apply,
Flax unboxes it automatically, so the layer's own code still sees a
plain array. To pull the spec out yourself:
from flax import linen as nn
from flax.linen import Partitioned

leaf = params["params"]["kernel"]
if isinstance(leaf, Partitioned):
    names = leaf.names  # (None, "model")
    value = leaf.value  # the actual array
Tools like nn.get_partition_spec(params) walk the pytree and
produce a parallel pytree of PartitionSpec objects. Combined with a
device mesh (for example via jax.sharding.NamedSharding), those
specs can be fed into jax.device_put or pjit.
Why annotate at definition time?
Two reasons:
- Locality. The author of the layer knows how its weights
should be sharded. Embedding that knowledge in the layer means
callers don’t need to plumb a separate sharding tree to every
init.
- Initializer participation. Some inits (e.g., Truncated/Glorot) require knowing the full fan-in/fan-out, not the per-shard fan. Wrapping the init lets Flax compute the full value first and shard it after, so the statistics are correct.
Single-device behavior
On a single device with no mesh:
- nn.with_partitioning(init_fn, names) still returns a valid initializer.
- init runs and returns a Partitioned-wrapped array of the same shape and dtype as the unwrapped init would have produced.
- No actual sharding happens. The value is on the one device, and the metadata tags travel along for free.
This is the operating mode for our sandbox: the annotation is inert, but the API surface is identical to the multi-device case.
Common pitfalls
- Wrong tuple length: (None, "model") for a 2-D kernel is fine; ("model",) for a 2-D kernel raises a sharding rank-mismatch error when the metadata is later consumed.
- Forgetting bias: in column-parallel Dense the bias is also
sharded along the output axis. If you partition the kernel as
(None, "model"), partition the bias as ("model",).
- Mixing with bare initializers: kernel_init = lecun_normal() (no wrap) can coexist with a wrapped bias_init; the kernel just won’t have sharding metadata. The compiler will fall back to replicated for that param, which is usually NOT what you want for big weights.
Problem
Build a nn.Dense(features) whose kernel is annotated with
nn.with_partitioning(nn.initializers.lecun_normal(), (None, "model")),
init it on a (seed, x) pair, and return the kernel’s logical shape
as a 1-D (2,) array [in_dim, features] (the partitioning metadata
doesn’t change the shape; it just rides alongside).
For a (N, in_dim) input, the kernel is (in_dim, features), so:
- in_dim = x.shape[-1]
- out_dim = features
Cast both to float and stack into a 1-D (2,) array.
Inputs:
- seed: float (cast to int), the PRNG seed.
- x: 2-D (N, in_dim) input.
- features: float (cast to int), the Dense output dimension.
Output: 1-D (2,) array [in_dim, features].