
NamedSharding (Replicated)

Why this matters

Not every tensor should be sharded. Small tensors, such as optimizer hyperparameters, layer-norm scale and bias vectors, and positional encodings, are cheapest to replicate across every device. Each device then holds a full copy, so reading the tensor requires no cross-device communication.

An empty PartitionSpec() (no arguments) expresses full replication:

sharding = NamedSharding(mesh, PartitionSpec())
x_rep = jax.device_put(x, sharding)
# Every device holds all of x.

On a single device, replication is effectively a no-op: sharding only changes where data lives, never its values. On many devices, device_put copies the full array to every device up front, so each rank can read it locally without an all_gather.
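A quick way to check these claims is to replicate a small array over a 1-D mesh built from every device JAX can see and then inspect the result. This is a sketch: the array `x` and the choice to use all visible devices (instead of only the first) are illustrative assumptions, not part of the problem.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative assumption: a 1-D mesh over every visible device (1 on a plain CPU host).
mesh = Mesh(np.array(jax.devices()), ('data',))
sharding = NamedSharding(mesh, PartitionSpec())  # fully replicated

x = jnp.arange(6.0)
x_rep = jax.device_put(x, sharding)

# Replication changes placement, not values.
values_match = bool(jnp.array_equal(x_rep, x))
# The sharding reports full replication regardless of device count.
fully_replicated = bool(x_rep.sharding.is_fully_replicated)
# Each device's local shard is the whole array.
shard_shapes = [s.data.shape for s in x_rep.addressable_shards]

print(values_match, fully_replicated, shard_shapes)
```

On a single-CPU host this prints `True True [(6,)]`; with more devices, the list gets one `(6,)` entry per device.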

Worked mini-example

import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))
sharding = NamedSharding(mesh, PartitionSpec())   # fully replicated

bias = jnp.array([0.1, -0.2, 0.3])
bias_rep = jax.device_put(bias, sharding)
print(jnp.sum(bias_rep))   # approximately 0.2; float32 rounding may show extra digits

Common pitfalls

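  • Don’t replicate large parameters: a replicated tensor is stored in full on every device, so its per-device memory cost does not shrink as you add devices, which defeats the purpose of distributed training. Reserve PartitionSpec() for small tensors.
  • PartitionSpec() vs PartitionSpec(None): both replicate. PartitionSpec(None) explicitly marks dimension 0 as unsharded (any remaining dimensions are implicitly unsharded too), so it requires an array of rank at least 1. The empty spec PartitionSpec() works for any rank, including scalars, which makes it the simpler shape-agnostic “replicate everything” choice.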
  • Axis names are irrelevant for fully replicated arrays: the mesh is still required by NamedSharding, even though no axis is used in the spec.
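The pitfalls above can be seen directly. In this sketch, assuming a mesh over all visible devices, both specs replicate a matrix, only the empty spec is applied to the scalar, and `shard_shape` confirms that each device's copy has the full global shape, which is why replicating large parameters is costly. The `matrix` and `scalar` arrays are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ('data',))
empty_spec = NamedSharding(mesh, PartitionSpec())     # any rank, including scalars
none_spec = NamedSharding(mesh, PartitionSpec(None))  # dim 0 unsharded; needs rank >= 1

matrix = jnp.ones((4, 3))
scalar = jnp.array(0.5)

a = jax.device_put(matrix, empty_spec)
b = jax.device_put(matrix, none_spec)
s = jax.device_put(scalar, empty_spec)  # PartitionSpec(None) would be rejected here

# Both specs fully replicate the matrix.
both_replicated = a.sharding.is_fully_replicated and b.sharding.is_fully_replicated
# Under replication each device's shard has the full global shape,
# so per-device memory does not shrink as devices are added.
per_device_shape = empty_spec.shard_shape(matrix.shape)

print(both_replicated, per_device_shape, s.sharding.is_fully_replicated)
```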

Problem

Implement replicated_sharding_value(x) that:

  1. Creates a 1-device Mesh with axis name 'data'.
  2. Builds NamedSharding(mesh, PartitionSpec()) — fully replicated.
  3. Places x on the sharding with jax.device_put.
  4. Returns jnp.sum(x_sharded).

Single-device caveat: the test runner has exactly 1 CPU device. On 1 device the replicated sharding is a metadata annotation only — values are unchanged. The same API call on a multi-device host would broadcast the full array to every device before summing.

  • x: 1-D JAX array.

Returns: scalar — jnp.sum(x).

Examples (not from the test set):

  • replicated_sharding_value(jnp.array([1.0, 2.0, 3.0])) → 6.0
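The single-device caveat above can be explored on one CPU host by asking XLA to expose several virtual CPU devices. This sketch assumes the XLA testing flag `--xla_force_host_platform_device_count` and a fresh Python process; it is not something the test runner does, and the flag has no effect if JAX has already initialized its CPU backend.

```python
import os

# Assumption: request 4 virtual CPU devices. This must happen before
# JAX initializes its CPU backend (i.e. before the first jax import).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

cpus = jax.devices("cpu")
mesh = Mesh(np.array(cpus), ('data',))
sharding = NamedSharding(mesh, PartitionSpec())  # fully replicated

x = jnp.array([1.0, 2.0, 3.0])
x_rep = jax.device_put(x, sharding)

# One full copy of x per (virtual) CPU device.
copies = [np.asarray(s.data).tolist() for s in x_rep.addressable_shards]
total = jnp.sum(x_rep)

print(len(copies), copies)
print(total)  # 6.0
```

In a fresh process this should report four copies of `[1.0, 2.0, 3.0]`; the sum is the same 6.0 as on one device, because replication never changes values.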

Hints

jax named-sharding device-put
