hard primitives

with_sharding_constraint

Why this matters

jax.lax.with_sharding_constraint(x, sharding) lets you annotate intermediate tensors inside a JIT-compiled function to tell XLA how to shard them. This is the standard pattern for mixed-sharding pipelines where an intermediate result should be replicated even if the inputs are sharded (or vice versa).

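import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()[:1])             # use only the first device
mesh = Mesh(devices, ('data',))                   # 1-D mesh with axis 'data'
sharding = NamedSharding(mesh, PartitionSpec())   # fully replicated

x = jnp.arange(4.0)                               # example input
x_c = jax.lax.with_sharding_constraint(x, sharding)
result = jnp.sum(x_c)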

Worked mini-example

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

x = jnp.array([1.0, 2.0, 3.0])
devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))
sharding = NamedSharding(mesh, PartitionSpec())
x_c = jax.lax.with_sharding_constraint(x, sharding)
print(jnp.sum(x_c))  # 6.0
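
The same computation can also be wrapped in jax.jit, which is the setting the constraint is meant for. This is a minimal sketch that assumes one available device; the names constrained_sum and replicated are hypothetical and only used for illustration:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1-device mesh with a single axis named 'data'.
mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
replicated = NamedSharding(mesh, PartitionSpec())  # no axis names: replicated

@jax.jit
def constrained_sum(x):  # hypothetical name, for illustration only
    # Inside jit, the constraint applies to the intermediate value x_c.
    x_c = jax.lax.with_sharding_constraint(x, replicated)
    return jnp.sum(x_c)

total = constrained_sum(jnp.array([1.0, 2.0, 3.0]))
print(total)  # 6.0
```

On a single device the constraint cannot change the result; it only fixes the layout the compiler must use for x_c.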

Common pitfalls

  • Constraint outside jit does nothing useful: with_sharding_constraint is intended for use inside a JIT-compiled function, where it constrains how XLA lays out an intermediate value. Called outside jit it does not error and the values are unchanged, but it cannot influence how any surrounding computation is partitioned.
  • PartitionSpec() means fully replicated: an empty PartitionSpec() (no axis names) tells XLA to replicate the tensor on every device in the mesh rather than shard it.
  • NamedSharding requires a Mesh: build the mesh first, then create the NamedSharding from it.
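
The last two points can be checked by inspecting the sharding objects directly. This is a rough sketch that assumes a 1-device mesh; the variable names replicated and split are hypothetical, and split is shown only for contrast:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# The mesh comes first; a NamedSharding is always built from one.
mesh = Mesh(np.array(jax.devices()[:1]), ('data',))

replicated = NamedSharding(mesh, PartitionSpec())   # no axis names
split = NamedSharding(mesh, PartitionSpec('data'))  # dim 0 split over 'data'

# An empty spec means every device holds the whole array.
print(replicated.is_fully_replicated)  # True
print(replicated.shard_shape((8,)))    # (8,)

# Both shardings refer back to the mesh they were built from.
print(replicated.mesh.axis_names)      # ('data',)
print(split.spec)
```

With only one device in the mesh the two shardings place data identically; the difference becomes visible once the 'data' axis spans several devices.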

Problem

Implement constrained_sharding(x) that:

  1. Creates a 1-device Mesh with axis name 'data'.
  2. Creates a NamedSharding with PartitionSpec() (fully replicated).
  3. Applies jax.lax.with_sharding_constraint(x, sharding) to x.
  4. Returns jnp.sum of the constrained array.

Arguments:

  • x: 1-D JAX array.

Returns: scalar (sum of x).

Examples (not from the test set):

  • constrained_sharding(jnp.array([1.0, 2.0, 3.0])) → 6.0

Hints

jax with-sharding-constraint intermediate-sharding
