
with_sharding_constraint Annotation

Why this matters

XLA’s auto-parallelization (GSPMD) is propagation-based: you annotate a few tensors with sharding info, and the compiler propagates the layout through the rest of the program. The compiler is good, but not omniscient. Sometimes it picks a bad layout that triggers expensive resharding mid-computation. The fix is jax.lax.with_sharding_constraint. Under jit it is a strict constraint on the partitioner rather than a soft hint: “this tensor MUST have this exact sharding at this point in the program; re-layout if needed to make it so.”

What it actually does

Despite the name, with_sharding_constraint is not a data transformation: it does not change the values or the logical shape of its input. It is a compile-time annotation recognized by jit / pjit, and any data movement needed to satisfy it is inserted by the compiler:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed setup: a 1-D mesh over all available devices, axis named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

@jax.jit
def forward(x, params):
    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P("data", None))
    )
    # x is now constrained to be sharded along the "data" mesh axis
    # at this point. The compiler may insert collective communication
    # here if x arrived with a different layout.
    return params @ x

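Outside of jit/pjit there is no traced program for the compiler to partition, so the call returns the same values with the same logical shape. On a single device that makes it effectively a no-op, which is why it is safe to leave in eager code paths and why single-device test runs work without a multi-device mesh. The exact eager behavior depends on the JAX version: recent releases may place the array according to the given sharding rather than ignore it, but the values are unchanged either way.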
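The passthrough behavior can be checked directly. The following is a minimal sketch, assuming a one-device mesh built from whatever device is available; the variable names are illustrative and not part of the original page.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a one-device mesh, so this also runs on a plain CPU install.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("data",))
sharding = NamedSharding(mesh, P("data", None))

x = jnp.arange(8.0).reshape(4, 2)

# Eager call: there is no jit around it.
y = jax.lax.with_sharding_constraint(x, sharding)

# The values and the logical shape come back untouched.
values_unchanged = bool(jnp.array_equal(x, y))
shape_unchanged = y.shape == x.shape
print(values_unchanged, shape_unchanged)
```

Only the values and shape are compared here; where the array physically lives is deliberately not asserted, since that part can differ between JAX versions.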

When you actually need it

  1. Activations between layers. The compiler sometimes propagates a layout that’s optimal locally but causes a transpose-heavy reshard between two layers. Pin the layout at the boundary.
  2. Forcing batch-sharded (data-parallel or FSDP-style) layouts on inputs. Inputs usually arrive unsharded; constrain them to P("data", None) at the top of the jitted function so the rest of the trace assumes that layout.
  3. Loss-scaling and reductions. After psum-style reductions you often want to re-pin to a known layout before the next op.
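  4. Debugging. Explicit constraints document the layout you intended, so an invalid spec fails at compile time and any reshard the compiler has to insert happens at a known point rather than wherever propagation happened to put it.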
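The first case in the list, pinning activations at a layer boundary, might look like the sketch below. The two-layer function, the shapes, and the weight values are made up for illustration; the mesh simply spans whatever devices are available.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed setup: 1-D mesh over all available devices, axis named "data".
n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
batch_sharded = NamedSharding(mesh, P("data", None))

@jax.jit
def two_layer(x, w1, w2):
    h = jnp.tanh(x @ w1)
    # Pin the activation layout at the layer boundary: batch dimension
    # split across "data", feature dimension replicated.
    h = jax.lax.with_sharding_constraint(h, batch_sharded)
    return h @ w2

batch = 4 * n_dev  # keep the batch divisible by the mesh axis size
x = jnp.ones((batch, 16))
w1 = jnp.full((16, 32), 0.01)
w2 = jnp.full((32, 8), 0.01)

out = two_layer(x, w1, w2)
reference = jnp.tanh(x @ w1) @ w2  # same math, no constraint
```

The constraint changes where the data lives, not what is computed, so out should match reference up to floating-point tolerance.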

The “no-op in eager” rule

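with_sharding_constraint(x, sharding) outside jit, on a single device, has these semantics:

  • Returns the values of x unchanged, with the same logical shape.
  • The sharding argument must still be a valid sharding for x, but on one device applying it changes nothing. (With several devices, recent JAX versions may actually move the data.)
  • Useful side effect: you can keep the constraint in your code path and run unit tests on a single device, using a one-device mesh.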

This is why our sandbox version below just returns the data — the real production call would be wrapped in a jit over a multi-device mesh, but the function body is identical.

The full production pattern

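import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Requires exactly 8 devices.
devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, axis_names=("data",))

@jax.jit
def step(x, params):
    # Pin the input layout: rows of x split across the "data" axis.
    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P("data", None))
    )
    out = params @ x
    # Pin the output layout so downstream code sees a known sharding.
    out = jax.lax.with_sharding_constraint(
        out, NamedSharding(mesh, P("data", None))
    )
    return out

# x_global and params are arrays defined elsewhere. The mesh context is
# only required for bare PartitionSpecs; NamedSharding carries its mesh.
with mesh:
    y = step(x_global, params)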
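To try the same pattern without eight devices, the mesh can be sized from whatever is available. This is a hedged variant of the pattern above, not the page's reference code: the identity-matrix params and the input shape are chosen only so that the result is easy to check.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: use every available device instead of a fixed count of 8.
n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
row_sharded = NamedSharding(mesh, P("data", None))

@jax.jit
def step(x, params):
    x = jax.lax.with_sharding_constraint(x, row_sharded)
    out = params @ x
    return jax.lax.with_sharding_constraint(out, row_sharded)

rows = 2 * n_dev  # divisible by the "data" axis size
x_global = jnp.arange(rows * 3, dtype=jnp.float32).reshape(rows, 3)
params = jnp.eye(rows, dtype=jnp.float32)  # identity, so y should equal x_global

y = step(x_global, params)
print(y.sharding)  # inspect the layout the compiler produced
```

Printing y.sharding is a quick way to see which layout the output ended up with; on one device it will not show anything interesting, but the same script can be run unchanged on a multi-device host.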

Common pitfalls

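  • Calling with a bare PartitionSpec instead of a full Sharding: with_sharding_constraint(x, P("data")) only works when a mesh is active in the surrounding context (for example inside with mesh:); without one it raises. Wrapping it is unambiguous: NamedSharding(mesh, P("data")).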
  • Expecting a reshape: it does NOT change the logical shape; only the layout across devices.
  • Forgetting to wrap in jit: outside jit there is no traced program to partition, so the constraint cannot steer the layout of intermediate values. At most it affects the one array you pass in.
  • Constraining intermediate values to invalid layouts: e.g., pinning P("data", "model") when the tensor’s shape isn’t divisible by the mesh axis sizes raises a sharding error at compile time.
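A cheap guard against the last pitfall is to check rank and divisibility yourself before constraining. The helper below is hypothetical (spec_fits is not a JAX function) and is only a sketch: it takes a plain dict of mesh axis sizes, which is what mesh.shape provides, and it does not handle PartitionSpec.UNCONSTRAINED entries.

```python
from jax.sharding import PartitionSpec as P

def spec_fits(shape, axis_sizes, spec):
    """Hypothetical helper: True if `spec` fits `shape` on the given mesh axes.

    axis_sizes maps mesh axis name -> size, e.g. dict(mesh.shape).
    """
    if len(spec) > len(shape):
        return False  # spec names more dimensions than the tensor has
    for dim_size, axes in zip(shape, spec):
        if axes is None:
            continue  # this dimension is replicated
        names = axes if isinstance(axes, tuple) else (axes,)
        n_shards = 1
        for name in names:
            n_shards *= axis_sizes[name]
        if dim_size % n_shards != 0:
            return False  # dimension does not split evenly
    return True

# Example axis sizes, written out by hand so no devices are needed.
axis_sizes = {"data": 8, "model": 2}
ok = spec_fits((16, 4), axis_sizes, P("data", "model"))
bad = spec_fits((12, 4), axis_sizes, P("data", "model"))
```

In real code the second argument would typically be dict(mesh.shape) for the mesh the NamedSharding is built on.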

Problem

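Implement the passthrough semantics that with_sharding_constraint has outside of jit on a single device. Take an input array x and return it as a flat 1-D array (x.reshape(-1)). The flattening is only this exercise’s output format; with_sharding_constraint itself never changes the logical shape. On a single device the constraint is a no-op, so the values pass through unchanged.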

The point of this exercise is the description: in real code you’d wrap this in jit and pass a NamedSharding(mesh, P(...)) as the second argument, and the compiler would enforce the layout. Here we practice the call shape — input in, same data out.

Inputs:

  • x: arbitrary-shape float array.

Output: 1-D flat array containing the same values as x.
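
One possible shape for an answer, as a minimal sketch rather than the site's reference solution. The function name passthrough_flat is made up for illustration.

```python
import jax.numpy as jnp

def passthrough_flat(x):
    """Hypothetical solution: return x's values unchanged, as a 1-D array."""
    # No sharding is applied here; on a single device the constraint
    # would be a no-op, so only the flattening remains.
    return jnp.asarray(x).reshape(-1)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
flat = passthrough_flat(x)
```

reshape(-1) flattens in row-major order, so the values come out in the same order they are stored in x.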

