hard
primitives
with_sharding_constraint
Why this matters
jax.lax.with_sharding_constraint(x, sharding) lets you annotate
intermediate tensors inside a JIT-compiled function to tell XLA how to
shard them. This is the standard pattern for mixed-sharding pipelines where
an intermediate result should be replicated even if the inputs are sharded
(or vice versa).
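import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
# x is any JAX array, e.g. an intermediate value inside a jitted function
devices = np.array(jax.devices()[:1])  # use only the first device
mesh = Mesh(devices, ('data',))
sharding = NamedSharding(mesh, PartitionSpec()) # fully replicated
x_c = jax.lax.with_sharding_constraint(x, sharding)
result = jnp.sum(x_c)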
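To make the "inside a JIT-compiled function" part concrete, here is a minimal sketch that constrains an intermediate value within a jitted function. The helper name scaled_sum and the doubling step are illustrative assumptions, not part of the problem; the one-device mesh is used only so the sketch runs on any machine.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-device mesh so the sketch runs anywhere; a real pipeline would span more devices.
mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
replicated = NamedSharding(mesh, PartitionSpec())  # empty spec: fully replicated

@jax.jit
def scaled_sum(x):  # hypothetical helper, for illustration only
    y = x * 2.0  # intermediate result
    # Ask XLA to keep the intermediate replicated, whatever the input sharding is.
    y = jax.lax.with_sharding_constraint(y, replicated)
    return jnp.sum(y)

total = scaled_sum(jnp.array([1.0, 2.0, 3.0]))
print(total)  # 12.0
```

The constraint changes only how the intermediate is laid out across devices; the numerical result is the same as without it.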
Worked mini-example
import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
x = jnp.array([1.0, 2.0, 3.0])
devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))
sharding = NamedSharding(mesh, PartitionSpec())
x_c = jax.lax.with_sharding_constraint(x, sharding)
print(jnp.sum(x_c)) # 6.0
Common pitfalls
- Constraint outside jit does nothing useful: with_sharding_constraint is a compiler hint. Outside of a JIT-compiled context it is a no-op for the computation (but does not error).
- PartitionSpec() means fully replicated: passing an empty PartitionSpec() (no axis names) tells XLA to replicate the tensor across all devices rather than shard it.
- NamedSharding requires a Mesh: build the mesh first, then create the NamedSharding from it.
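The pitfalls above can be checked with a short sketch. It assumes a reasonably recent JAX release in which shardings expose the is_fully_replicated property; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build the mesh first, then the NamedSharding that refers to it.
mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
sharding = NamedSharding(mesh, PartitionSpec())  # no axis names

# An empty PartitionSpec describes a fully replicated layout.
is_replicated = sharding.is_fully_replicated

# Calling the constraint outside jit does not raise and leaves the values unchanged.
x = jnp.arange(4.0)
x_eager = jax.lax.with_sharding_constraint(x, sharding)
same_values = bool(jnp.array_equal(x, x_eager))

print(is_replicated, same_values)
```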
Problem
Implement constrained_sharding(x) that:
- Creates a 1-device Mesh with axis name 'data'.
- Creates a NamedSharding with PartitionSpec() (fully replicated).
- Applies jax.lax.with_sharding_constraint(x, sharding) to x.
- Returns jnp.sum of the constrained array.
- x: 1-D JAX array.
Returns: scalar (sum of x).
Examples (not from the test set):
- constrained_sharding(jnp.array([1.0, 2.0, 3.0])) → 6.0
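One way the four steps might fit together is sketched below. This is an assumption-laden illustration assembled from the worked mini-example, not the site's reference solution; wrapping the call in jax.jit is an optional choice made here so that the constraint is seen by the compiler.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def constrained_sharding(x):
    # Step 1: 1-device mesh with a single axis named 'data'.
    devices = np.array(jax.devices()[:1])
    mesh = Mesh(devices, ('data',))
    # Step 2: empty PartitionSpec means fully replicated.
    sharding = NamedSharding(mesh, PartitionSpec())
    # Step 3: annotate x with the sharding constraint.
    x_c = jax.lax.with_sharding_constraint(x, sharding)
    # Step 4: reduce the constrained array to a scalar.
    return jnp.sum(x_c)

# Jitting is optional here; under jit the constraint acts as a hint to XLA.
result = jax.jit(constrained_sharding)(jnp.array([1.0, 2.0, 3.0]))
print(result)  # 6.0
```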
Hints
jax
with-sharding-constraint
intermediate-sharding