pjit Basics

Why this matters

pjit was the original “sharded jit” API in JAX. Modern JAX has merged it into jax.jit, which exposes the same functionality through its in_shardings and out_shardings keyword arguments. Understanding pjit means understanding SPMD (Single Program, Multiple Data) execution: the same compiled program runs on every device, and each device operates on its own shard of the data.

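import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

x = jnp.array([1.0, 2.0, 3.0, 4.0])
devices = np.array(jax.devices()[:1])                  # a 1-device mesh
mesh = Mesh(devices, ('data',))                        # one mesh axis named 'data'
sharding = NamedSharding(mesh, PartitionSpec('data'))  # split axis 0 over 'data'
f = jax.jit(lambda y: 2 * y, in_shardings=sharding, out_shardings=sharding)
result = f(x)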

On a single device, the 'data' mesh axis has size 1, so the only shard is the whole array: nothing is actually partitioned, and the result is identical to 2 * x.
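One way to check the single-device claim is to inspect the output's sharding and its per-device shards. This sketch assumes a recent JAX in which jax.Array exposes .sharding and .addressable_shards; the printed device names depend on the backend.

```python
import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

x = jnp.array([1.0, 2.0, 3.0, 4.0])
mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data'))
f = jax.jit(lambda y: 2 * y, in_shardings=sharding, out_shardings=sharding)
result = f(x)

# The output carries the sharding requested via out_shardings.
print(result.sharding)

# With a 1-device mesh there is exactly one shard, and it holds the whole array.
for shard in result.addressable_shards:
    print(shard.device, shard.index, shard.data)
```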

Worked mini-example

import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

x = jnp.array([1.0, 2.0, 3.0, 4.0])
devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data'))
f = jax.jit(lambda y: 2 * y, in_shardings=sharding, out_shardings=sharding)
print(f(x))  # [2. 4. 6. 8.]
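The same pattern runs unchanged on more devices. This sketch is an illustration beyond the problem itself: it builds the mesh from every visible device and picks an input length that divides evenly across them. On a CPU-only machine, setting the environment variable XLA_FLAGS=--xla_force_host_platform_device_count=4 before importing jax is one common way to simulate several devices. With n devices, each shard should hold 2 elements.

```python
import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Use every visible device instead of just the first one.
devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data'))

# Length 2 * n, so axis 0 splits evenly into n blocks of 2 elements.
x = jnp.arange(2 * n, dtype=jnp.float32)
f = jax.jit(lambda y: 2 * y, in_shardings=sharding, out_shardings=sharding)
result = f(x)

# Every device ran the same program on its own contiguous block (SPMD).
for shard in result.addressable_shards:
    print(shard.device, shard.index, shard.data)
```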

Common pitfalls

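  • Shape-shard mismatch: with PartitionSpec('data'), the size of the array’s leading axis must be divisible by the size of the 'data' mesh axis (the number of devices along it). On a 1-device mesh this always holds.
  • in_shardings / out_shardings must match the function’s structure: in_shardings is either a single sharding, applied to every positional argument, or a tuple with one entry (possibly a pytree) per positional argument. out_shardings follows the same rule for the outputs.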
  • Old pjit import path: jax.experimental.pjit.pjit is deprecated. Use jax.jit(f, in_shardings=..., out_shardings=...) instead.
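The first two pitfalls can be illustrated together. This sketch assumes a 1-device mesh; divides_evenly is a hypothetical helper, not a JAX function, and PartitionSpec() with no axis names marks an argument as replicated rather than split.

```python
import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
row = NamedSharding(mesh, PartitionSpec('data'))  # split axis 0 over 'data'
rep = NamedSharding(mesh, PartitionSpec())        # fully replicated

def divides_evenly(arr, mesh, axis_name='data'):
    """Hypothetical helper: True if arr's axis 0 splits evenly over the mesh axis."""
    return arr.shape[0] % mesh.shape[axis_name] == 0

# Two positional arguments, so in_shardings is a tuple with one entry each.
add_scaled = jax.jit(
    lambda a, b: a + 2 * b,
    in_shardings=(row, rep),
    out_shardings=row,
)

a = jnp.ones(4)
b = jnp.ones(4)
assert divides_evenly(a, mesh)
print(add_scaled(a, b))  # [3. 3. 3. 3.]
```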

Problem

Implement pjit_double(x) that:

  1. Creates a 1-device Mesh with axis name 'data'.
  2. Creates a NamedSharding with PartitionSpec('data') (shard along axis 0).
  3. JIT-compiles lambda y: 2 * y with in_shardings=sharding, out_shardings=sharding.
  4. Returns the result of calling that function on x.

Arguments:

  • x: 1-D JAX array.

Returns: 1-D array, same shape as x, with each element doubled.

Examples (not from the test set):

  • pjit_double(jnp.array([1.0, 2.0])) → [2.0, 4.0]
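A minimal sketch of pjit_double that follows the four steps literally, using only the API calls from the worked mini-example. It is one possible solution, not necessarily the reference one.

```python
import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def pjit_double(x):
    # 1. One-device mesh whose single axis is named 'data'.
    mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
    # 2. Shard axis 0 of the input over the 'data' mesh axis.
    sharding = NamedSharding(mesh, PartitionSpec('data'))
    # 3. Compile the doubling function with matching input/output shardings.
    f = jax.jit(lambda y: 2 * y, in_shardings=sharding, out_shardings=sharding)
    # 4. Apply it: same shape as x, every element doubled.
    return f(x)

print(pjit_double(jnp.array([1.0, 2.0])))  # [2. 4.]
```

Because this builds a fresh jitted lambda on every call, JAX will likely retrace (and may recompile) each time; if pjit_double is called repeatedly, hoisting the mesh, sharding, and jitted function out of the body would avoid that.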
