with_sharding_constraint Annotation
Why this matters
XLA’s auto-parallelization (GSPMD) is propagation-based: you
annotate a few tensors with sharding info, and the compiler
propagates the layout through the rest of the program. The compiler
is good — but not omniscient. Sometimes it picks a bad layout that
triggers expensive resharding mid-computation. The fix is
jax.lax.with_sharding_constraint: a hint that says “this tensor
MUST have this exact sharding from here on — re-layout if needed
to make it so.”
What it actually does
Despite the name, with_sharding_constraint is not a data
transformation. It does not reshape, gather, or scatter. It is a
compile-time annotation recognized by jit / pjit:
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

# `mesh` is a jax.sharding.Mesh with a "data" axis, built elsewhere.

@jax.jit
def forward(x, params):
    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P("data", None))
    )
    # x is now constrained to be sharded along the "data" mesh axis
    # at this point in the jit'd function. The compiler may insert
    # an all-to-all here if x arrived with a different layout.
    return params @ x
Outside of jit/pjit, the call passes its input through unchanged
— there’s no compiler watching, so there’s no layout to enforce.
That makes the function safe to put in eager code paths (it’s a no-op
there) and makes single-device test runs work without a real mesh.
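A minimal sketch of the passthrough behavior described above, assuming a one-device mesh built from whatever `jax.devices()` returns. The mesh, the array values, and the variable names are illustrative and not from the original. It checks only that shape and values survive the call; how the layout is handled eagerly on a multi-device mesh may depend on the JAX version.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a one-device mesh, so the example runs on any machine.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("data",))
sharding = NamedSharding(mesh, P("data", None))

x = jnp.arange(12.0).reshape(4, 3)

# Called eagerly (no jit): the values and the logical shape come back unchanged.
y = jax.lax.with_sharding_constraint(x, sharding)

print(y.shape)
print(bool(jnp.array_equal(x, y)))
```

The same call can stay in place when the surrounding function is later wrapped in `jax.jit` over a larger mesh.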
When you actually need it
- Activations between layers. The compiler sometimes propagates a layout that’s optimal locally but causes a transpose-heavy reshard between two layers. Pin the layout at the boundary.
- Forcing FSDP layouts on inputs. Inputs come in unsharded by default; constrain them to P("data", None) immediately so the rest of the trace assumes that layout.
- Loss-scaling and reductions. After psum-style reductions you often want to re-pin to a known layout before the next op.
- Debugging. Adding constraints surfaces sharding mismatches as errors rather than silent reshards.
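The first case, pinning activations at a layer boundary, could look roughly like the sketch below. The two-layer function, the weight shapes, and the names `two_layer`, `two_layer_reference`, and `act_sharding` are hypothetical; the mesh is assumed to be a 1-D "data" mesh over all visible devices, and the batch size is chosen to be divisible by the device count.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D "data" mesh over every visible device.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
act_sharding = NamedSharding(mesh, P("data", None))

@jax.jit
def two_layer(x, w1, w2):
    h = jnp.tanh(x @ w1)
    # Pin the activation layout at the layer boundary: batch split over "data".
    h = jax.lax.with_sharding_constraint(h, act_sharding)
    return h @ w2

def two_layer_reference(x, w1, w2):
    # Same math with no constraint, used only to compare values.
    return jnp.tanh(x @ w1) @ w2

batch = 4 * devices.size  # divisible by the "data" axis size
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (batch, 8))
w1 = jax.random.normal(k2, (8, 16))
w2 = jax.random.normal(k3, (16, 2))

out = two_layer(x, w1, w2)
ref = two_layer_reference(x, w1, w2)
print(out.shape)
```

The constraint affects only how `h` is laid out across devices, so `out` and `ref` should agree up to floating-point tolerance.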
The “no-op in eager” rule
with_sharding_constraint(x, sharding) outside jit has these
semantics:
- Returns x unchanged.
- The sharding object is validated (must be a real Sharding), but not applied.
- Useful side effect: you can keep the constraint in your code path and run unit tests on a single device with no mesh.
This is why our sandbox version below just returns the data — the
real production call would be wrapped in a jit over a multi-device
mesh, but the function body is identical.
The full production pattern
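import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes exactly 8 devices are available.
devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, axis_names=("data",))

@jax.jit
def step(x, params):
    # Pin the input layout before the matmul.
    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P("data", None))
    )
    out = params @ x
    # Pin the output layout so downstream ops see a known sharding.
    out = jax.lax.with_sharding_constraint(
        out, NamedSharding(mesh, P("data", None))
    )
    return out

# x_global and params are arrays created elsewhere.
with mesh:
    y = step(x_global, params)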
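To see what a constrained jit output looks like on a machine that does not have 8 devices, a device-count-agnostic sketch such as the following may help. The function `scale`, the array `x_global`, and the name `row_sharding` are illustrative assumptions; the mesh is built from whatever `jax.devices()` returns.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: use every visible device instead of a fixed count of 8.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
row_sharding = NamedSharding(mesh, P("data", None))

@jax.jit
def scale(x):
    y = x * 2.0
    # Ask the compiler to lay the result out with rows split over "data".
    return jax.lax.with_sharding_constraint(y, row_sharding)

x_global = jnp.ones((2 * devices.size, 4))
y = scale(x_global)

print(y.shape)     # logical shape is unchanged by the constraint
print(y.sharding)  # the sharding the runtime reports for the result
for shard in y.addressable_shards:
    print(shard.device, shard.data.shape)  # per-device slice of y
```

Printing `y.sharding` and the per-device shard shapes is a quick way to check whether the layout that came out matches the one that was requested.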
Common pitfalls
- Calling with a PartitionSpec directly instead of a full Sharding: with_sharding_constraint(x, P("data")) raises in modern JAX when no mesh is active in the surrounding context. Wrap it: NamedSharding(mesh, P("data")).
- Expecting a reshape: it does NOT change the logical shape; only the layout across devices.
- Forgetting to wrap in jit: outside jit, the constraint is a no-op, so a multi-device program with no jit keeps the input layout regardless.
- Constraining intermediate values to invalid layouts: e.g., pinning P("data", "model") when the tensor’s shape isn’t divisible by the mesh axis sizes raises a sharding error at compile time.
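The divisibility pitfall can be checked ahead of time in plain Python. The helper below is a hypothetical sketch and not part of JAX: it takes the array shape, the per-dimension spec entries, and a mapping of mesh axis names to sizes, and reports which dimensions would not split evenly.

```python
from math import prod

def check_divisible(shape, spec, axis_sizes):
    """Return a list of problems; an empty list means every dim divides evenly.

    shape:      tuple of ints, the array's logical shape.
    spec:       per-dimension entries as in a PartitionSpec: None, an axis
                name, or a tuple of axis names.
    axis_sizes: mapping from mesh axis name to its size.
    """
    problems = []
    if len(spec) > len(shape):
        problems.append(
            f"spec has {len(spec)} entries but the array has rank {len(shape)}"
        )
        return problems
    for dim, (size, entry) in enumerate(zip(shape, spec)):
        if entry is None:
            continue  # this dimension is not sharded
        names = (entry,) if isinstance(entry, str) else tuple(entry)
        shards = prod(axis_sizes[name] for name in names)
        if size % shards != 0:
            problems.append(
                f"dim {dim} of size {size} is not divisible by {shards} shards {names}"
            )
    return problems

# Illustrative mesh axis sizes. With a real Mesh, dict(mesh.shape) should
# supply this mapping (stated here as an assumption).
axis_sizes = {"data": 4, "model": 2}
ok = check_divisible((8, 6), ("data", "model"), axis_sizes)
bad = check_divisible((8, 5), ("data", "model"), axis_sizes)
print(ok)
print(bad)
```

Running a check like this in a unit test reports the problem before compilation, with a message that names the offending dimension.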
Problem
Implement the passthrough semantics that
with_sharding_constraint has outside of jit. Take an input
array x, return it as a flat 1-D array (x.reshape(-1)). On a
single device with no mesh, the constraint is a no-op — the values
pass through unchanged.
The point of this exercise is the description: in real code you’d
wrap this in jit and pass a NamedSharding(mesh, P(...)) as the
second argument, and the compiler would enforce the layout. Here we
practice the call shape — input in, same data out.
Inputs:
- x: arbitrary-shape float array.
Output: 1-D flat array containing the same values as x.
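One possible reading of the task, as a hedged sketch rather than the official solution: flatten the input, and apply the constraint only when a sharding is supplied. The function name `passthrough_flatten` and its optional `sharding` parameter are hypothetical.

```python
import jax
import jax.numpy as jnp

def passthrough_flatten(x, sharding=None):
    """Flatten x to 1-D, optionally applying a sharding constraint first.

    With sharding=None (single device, no mesh) the function only flattens,
    which is the passthrough case the problem describes.
    """
    x = jnp.asarray(x)
    if sharding is not None:
        # In real code this would run under jit with a NamedSharding(mesh, P(...)).
        x = jax.lax.with_sharding_constraint(x, sharding)
    return x.reshape(-1)

x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
flat = passthrough_flatten(x)
print(flat.shape)
print(flat)
```

Keeping `sharding` optional lets the same function body serve both the single-device sandbox and a jit-wrapped multi-device call.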