NamedSharding (Replicated)
Why this matters
Not every tensor should be sharded. Small constants, such as optimizer hyperparameters, layer-norm scale and bias, and positional encodings, are cheapest to replicate across every device. Every device then holds a full copy, so reading the tensor requires no cross-device communication.
An empty PartitionSpec() (no arguments) expresses full replication:
sharding = NamedSharding(mesh, PartitionSpec())
x_rep = jax.device_put(x, sharding)
# Every device holds all of x.
On 1 device, replication is a no-op — the values are unchanged. On many
devices, replication ensures every rank has the same copy without needing
an all_gather.
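One way to check the "every device holds a full copy" claim is to inspect the per-device shards after placement. The sketch below is illustrative only: it assumes a mesh built from all devices returned by jax.devices() (the array x and its length are arbitrary choices, not from the exercise) and relies on the is_fully_replicated property and the addressable_shards attribute.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumption: use every visible device; on a single-device host this is a 1-device mesh.
devices = np.array(jax.devices())
mesh = Mesh(devices, ('data',))
sharding = NamedSharding(mesh, PartitionSpec())  # empty spec: fully replicated

x = jnp.arange(6.0)  # arbitrary small 1-D array for illustration
x_rep = jax.device_put(x, sharding)

# The sharding itself reports whether it replicates the whole array.
print(sharding.is_fully_replicated)

# Each addressable shard should have the same shape as the full array.
for shard in x_rep.addressable_shards:
    print(shard.device, shard.data.shape)
```

On a multi-device host the loop prints one line per local device, each with the full shape; on one device it prints a single line.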
Worked mini-example
import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))
sharding = NamedSharding(mesh, PartitionSpec()) # fully replicated
bias = jnp.array([0.1, -0.2, 0.3])
bias_rep = jax.device_put(bias, sharding)
print(jnp.sum(bias_rep)) # ≈ 0.2 (float32 rounding may print 0.20000002)
Common pitfalls
- Don’t replicate large parameters: replication defeats the purpose of distributed training for tensors that consume significant memory. Reserve PartitionSpec() for truly small constants.
- PartitionSpec() vs PartitionSpec(None): both replicate, but None in a 1-element spec explicitly covers a 1-D array. For a shape-agnostic “replicate everything” pattern, the empty spec PartitionSpec() is simpler.
- Axis names are irrelevant for fully replicated arrays: the mesh is still required by NamedSharding, even though no axis is used in the spec.
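The first two pitfalls can be made concrete by comparing per-device shapes under three specs. This is a rough sketch, not part of the exercise: the array w, its length, and the variable names are assumptions chosen for illustration, and PartitionSpec('data') is included only as a contrast to show what splitting over the mesh axis looks like.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumption: use every visible device; with one device all three layouts coincide.
devices = np.array(jax.devices())
mesh = Mesh(devices, ('data',))
n = devices.size

empty_spec = NamedSharding(mesh, PartitionSpec())        # shape-agnostic replication
none_spec = NamedSharding(mesh, PartitionSpec(None))     # first dim listed, left unsharded
split_spec = NamedSharding(mesh, PartitionSpec('data'))  # first dim split over 'data'

# Length is a multiple of the device count so the split is even.
w = jnp.arange(4 * n, dtype=jnp.float32)

for name, s in [('P()', empty_spec), ('P(None)', none_spec), ("P('data')", split_spec)]:
    placed = jax.device_put(w, s)
    per_device = placed.addressable_shards[0].data.shape
    print(name, 'per-device shape:', per_device)
```

Both replicated specs keep the full length of w on each device, while the split spec keeps only w.shape[0] // n elements per device. That difference is the memory cost the first pitfall warns about.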
Problem
Implement replicated_sharding_value(x) that:
- Creates a 1-device Mesh with axis name 'data'.
- Builds NamedSharding(mesh, PartitionSpec()), which is fully replicated.
- Places x on the sharding with jax.device_put.
- Returns jnp.sum(x_sharded).
Single-device caveat: the test runner has exactly 1 CPU device. On 1 device the replicated sharding is a metadata annotation only — values are unchanged. The same API call on a multi-device host would broadcast the full array to every device before summing.
- x: 1-D JAX array.
Returns: scalar — jnp.sum(x).
Examples (not from the test set):
- replicated_sharding_value(jnp.array([1.0, 2.0, 3.0])) → 6.0
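The four steps in the problem statement can be sketched as follows. This is one possible reading of the steps, assembled from the worked mini-example, and not the site's reference solution; the local variable names are assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicated_sharding_value(x):
    # Step 1: a 1-device mesh with a single axis named 'data'.
    devices = np.array(jax.devices()[:1])
    mesh = Mesh(devices, ('data',))
    # Step 2: an empty PartitionSpec means no dimension is sharded.
    sharding = NamedSharding(mesh, PartitionSpec())
    # Step 3: place x according to the replicated sharding.
    x_sharded = jax.device_put(x, sharding)
    # Step 4: reduce to a scalar.
    return jnp.sum(x_sharded)


result = replicated_sharding_value(jnp.array([1.0, 2.0, 3.0]))
print(result)  # 6.0
```

Because the mesh holds a single device, the placement does not change any values, so the result matches jnp.sum(x) on the unsharded input.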