pjit Basics
Why this matters
pjit was the original “sharded jit” API in JAX. Modern JAX absorbed it
directly into jax.jit via the in_shardings and out_shardings keyword
arguments. Understanding pjit means understanding SPMD (Single Program,
Multiple Data) execution: the same compiled function runs on every device,
but each device works on its own shard of the data.
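import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
x = jnp.array([1.0, 2.0, 3.0, 4.0])
devices = np.array(jax.devices()[:1])  # a single device
mesh = Mesh(devices, ('data',))  # 1-D mesh with one axis named 'data'
sharding = NamedSharding(mesh, PartitionSpec('data'))  # split axis 0 over 'data'
f = jax.jit(lambda y: 2 * y, in_shardings=sharding, out_shardings=sharding)
result = f(x)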
On a single device, the sharding is only a metadata annotation: nothing is actually split or moved between devices, and the computation is equivalent to 2 * x.
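To make the SPMD idea concrete, here is a minimal sketch that goes beyond the original exercise. It assumes you want the mesh to span every device JAX can see, not just one, and it inspects the per-device pieces of the output through jax.Array.addressable_shards. The names n and x here are illustrative. On a machine with a single CPU device it reports one shard holding the whole array. With more devices, each device holds one contiguous slice.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumption: use all visible devices instead of jax.devices()[:1].
devices = np.array(jax.devices())
n = len(devices)
mesh = Mesh(devices, ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data'))

# Leading axis is a multiple of the device count, so it divides evenly.
x = jnp.arange(4 * n, dtype=jnp.float32)

f = jax.jit(lambda y: 2 * y, in_shardings=sharding, out_shardings=sharding)
result = f(x)

# Each device runs the same program on its own 4-element slice.
for shard in result.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```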
Worked mini-example
import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
x = jnp.array([1.0, 2.0, 3.0, 4.0])
devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data'))
f = jax.jit(lambda y: 2 * y, in_shardings=sharding, out_shardings=sharding)
print(f(x)) # [2. 4. 6. 8.]
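A closely related pattern places the data first and lets jit pick up the sharding from its argument. The sketch below uses jax.device_put to commit x to the sharding, then calls a plain jax.jit with no in_shardings or out_shardings. It is an alternative to the explicit keywords above, not a replacement for what the exercise asks for. In this variant the output sharding is chosen by the compiler instead of being pinned.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

x = jnp.array([1.0, 2.0, 3.0, 4.0])
mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data'))

# Commit the input to the sharding up front instead of passing in_shardings.
x_sharded = jax.device_put(x, sharding)

# With no explicit shardings, jit reads the input sharding from the argument
# and lets the compiler choose the output sharding.
g = jax.jit(lambda y: 2 * y)
out = g(x_sharded)
print(out)                 # [2. 4. 6. 8.]
print(x_sharded.sharding)  # the NamedSharding created above
```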
Common pitfalls
- Shape-shard mismatch: for PartitionSpec('data'), the array's leading axis must be divisible by the size of the 'data' mesh axis, which here is the number of devices. On 1 device this is always satisfied.
- in_shardings/out_shardings must be consistent with the function's arguments and outputs. in_shardings is either a single sharding applied to every positional argument or a tuple with one entry per positional argument.
- Old pjit import path: jax.experimental.pjit.pjit is deprecated. Use jax.jit(f, in_shardings=..., out_shardings=...) instead.
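The first two pitfalls can be checked in code. The sketch below uses a hypothetical helper, fits_mesh_axis, which is not a JAX API. It tests divisibility against the mesh axis size reported by Mesh.shape. A two-argument function then receives a tuple with one sharding per positional argument.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
sharding = NamedSharding(mesh, PartitionSpec('data'))

def fits_mesh_axis(arr, mesh, axis_name='data'):
    """Hypothetical helper: True if arr's leading axis divides evenly over the mesh axis."""
    return arr.shape[0] % mesh.shape[axis_name] == 0

a = jnp.array([1.0, 2.0, 3.0, 4.0])
b = jnp.array([10.0, 20.0, 30.0, 40.0])
assert fits_mesh_axis(a, mesh) and fits_mesh_axis(b, mesh)

# Two positional arguments, so a tuple with one sharding per argument.
add = jax.jit(lambda u, v: u + v,
              in_shardings=(sharding, sharding),
              out_shardings=sharding)
print(add(a, b))  # [11. 22. 33. 44.]
```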
Problem
Implement pjit_double(x) that:
- Creates a 1-device Mesh with axis name 'data'.
- Creates a NamedSharding with PartitionSpec('data') (shard along axis 0).
- JIT-compiles lambda y: 2 * y with in_shardings=sharding, out_shardings=sharding.
- Returns the result of calling that function on x.
Input: x, a 1-D JAX array.
Returns: 1-D array, same shape as x, with each element doubled.
Examples (not from the test set):
- pjit_double(jnp.array([1.0, 2.0])) → [2.0, 4.0]
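Here is a minimal sketch that follows the four steps of the specification literally and mirrors the worked mini-example. It is one possible implementation, not the official solution, and it assumes jax.devices() returns at least one device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def pjit_double(x):
    # 1-device mesh whose single axis is named 'data'.
    mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
    # Shard along axis 0 (trivial here, since there is only one device).
    sharding = NamedSharding(mesh, PartitionSpec('data'))
    f = jax.jit(lambda y: 2 * y, in_shardings=sharding, out_shardings=sharding)
    return f(x)

print(pjit_double(jnp.array([1.0, 2.0])))  # [2. 4.]
```

Because a new lambda and jit wrapper are built on every call, each call will typically be retraced. In longer-lived code you would build the mesh and the jitted function once and reuse them.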