NNX Tensor-Parallel Linear
Why this matters
Some layers are too big to fit on a single device: a Llama-70B
Linear(8192, 28672) kernel is about 470 MB at bf16 (about 940 MB at fp32).
Tensor parallelism (TP) carves a single layer’s weight matrix across
devices, so each device computes a fraction of the output and the
fractions are concatenated.
Compare to data parallelism, which replicates the whole layer and splits the batch. TP and DP are orthogonal — production transformers typically use both: data parallel across one mesh axis, tensor parallel across another.
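As a quick sanity check on the sizes involved, the arithmetic can be sketched in plain Python. The helper name kernel_megabytes and the choice of tp_size=8 are illustrative assumptions, not something the text prescribes.

```python
# Back-of-the-envelope kernel memory for the Linear(8192, 28672) example.
# Helper name and tp_size=8 are hypothetical, chosen only for illustration.
IN_DIM, OUT_DIM = 8192, 28672
BYTES_PER_PARAM = {"bf16": 2, "fp32": 4}


def kernel_megabytes(in_dim, out_dim, dtype, tp_size=1):
    """Size in MB (10**6 bytes) of one device's column shard of the kernel."""
    params = in_dim * (out_dim // tp_size)
    return params * BYTES_PER_PARAM[dtype] / 1e6


full_bf16 = kernel_megabytes(IN_DIM, OUT_DIM, "bf16")              # whole kernel, bf16
full_fp32 = kernel_megabytes(IN_DIM, OUT_DIM, "fp32")              # whole kernel, fp32
shard_bf16 = kernel_megabytes(IN_DIM, OUT_DIM, "bf16", tp_size=8)  # one of 8 column shards

print(f"full bf16: {full_bf16:.1f} MB, full fp32: {full_fp32:.1f} MB, "
      f"per-device bf16 shard: {shard_bf16:.1f} MB")
```

Splitting along the output axis divides the per-device kernel size by tp_size, which is the whole point of the column-parallel layout.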
Column-parallel Linear
The most common pattern: split the kernel along the output axis.
Given kernel: (in_dim, out_dim) and tp_size model-parallel
devices, let chunk = out_dim // tp_size. Each device i holds:
kernel_i = kernel[:, i*chunk:(i+1)*chunk]  # (in_dim, out_dim/tp)
bias_i = bias[i*chunk:(i+1)*chunk]  # (out_dim/tp,)
Each device computes x @ kernel_i + bias_i to get its slice of
the output (shape (N, out_dim/tp)). The full output is the
concatenation:
out = concat([x @ kernel_0 + bias_0, ..., x @ kernel_{tp-1} + bias_{tp-1}], axis=-1)
The mesh-axis name is conventionally "model". The PartitionSpec
is P(None, "model") for the kernel and P("model") for the
bias.
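To make the PartitionSpec layout concrete, here is a minimal sketch that places a kernel and bias on a one-axis "model" mesh. It assumes the standard jax.sharding API and uses whatever devices are visible (a single device in a sandbox, where the shard is simply the whole array); the shapes are made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Put every visible device on a single "model" axis (1 device in a sandbox).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))
tp_size = devices.size

# Illustrative shapes; out_dim is chosen so it divides evenly by tp_size.
in_dim, out_dim = 4, 8 * tp_size
kernel = jnp.ones((in_dim, out_dim))
bias = jnp.zeros((out_dim,))

# Kernel: replicate rows (None), split columns over "model".
kernel_sharded = jax.device_put(kernel, NamedSharding(mesh, P(None, "model")))
# Bias: split its only axis over "model" so it lines up with the kernel columns.
bias_sharded = jax.device_put(bias, NamedSharding(mesh, P("model")))

# Each addressable shard holds only the local slice.
local_kernel_shape = kernel_sharded.addressable_shards[0].data.shape
local_bias_shape = bias_sharded.addressable_shards[0].data.shape
print(local_kernel_shape, local_bias_shape)
```

The global shapes stay (in_dim, out_dim) and (out_dim,); only the per-device storage changes.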
Why x doesn’t shard
x: (N, in_dim) is replicated across the model-parallel devices
(it’s already broadcast to all of them). What’s sharded is the
kernel, not the input. After the matmul, each device has a
different slice of the output — concat reconstructs the full
(N, out_dim).
Contrast with row-parallel Linear (P("model", None)): the
kernel splits along in_dim, the input is split along the same
axis, each device produces a partial sum of the output, and a
final psum all-reduce assembles the full result.
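The row-parallel contrast can be simulated on one device in the same spirit. This is a sketch under assumptions: random arrays stand in for real parameters, a Python sum stands in for the psum all-reduce, and the bias is added once after the reduction.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes; in_dim must divide evenly by tp_size here.
tp_size, n, in_dim, out_dim = 4, 3, 8, 6
key_x, key_k, key_b = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (n, in_dim))
kernel = jax.random.normal(key_k, (in_dim, out_dim))
bias = jax.random.normal(key_b, (out_dim,))

chunk = in_dim // tp_size
partials = []
for i in range(tp_size):
    x_i = x[:, i * chunk:(i + 1) * chunk]       # input split along in_dim
    k_i = kernel[i * chunk:(i + 1) * chunk, :]  # kernel split along axis 0
    partials.append(x_i @ k_i)                  # full-width partial sum, shape (n, out_dim)

# Summing the partials plays the role of psum across devices.
row_parallel_out = sum(partials) + bias
reference = x @ kernel + bias
print(jnp.allclose(row_parallel_out, reference, atol=1e-4))
```

Note the difference in shapes: column-parallel shards produce (N, out_dim/tp) slices that are concatenated, while row-parallel shards each produce a full (N, out_dim) partial that must be summed.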
Why build the model normally first
A real TP setup constructs the kernel directly in the sharded
layout, so only the local slice is materialized on each device. In
our single-device sandbox we build a normal
nnx.Linear(in_dim, features), then manually slice its kernel in a loop.
The full kernel exists in memory; we ignore that and use only the
slices, to mimic what each device sees in production.
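import jax.numpy as jnp
from flax import nnx

# seed, in_dim, features, tp_size and x are assumed to be defined already.
model = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed))
kernel = model.kernel.value  # (in, out)
bias = model.bias.value  # (out,)
chunk = features // tp_size  # columns per simulated device
outs = []
for i in range(tp_size):
    k = kernel[:, i*chunk:(i+1)*chunk]  # this device's kernel columns
    b = bias[i*chunk:(i+1)*chunk]  # matching bias slice
    outs.append(x @ k + b)  # (N, chunk)
out = jnp.concatenate(outs, axis=-1)  # (N, features)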
The result matches model(x) up to floating-point rounding. That is the
correctness invariant of column-parallel TP:
concat([x @ K_i + b_i]) == x @ K + b.
Why call into raw arrays
model.kernel.value and model.bias.value unwrap the nnx.Param
Variables to plain JAX arrays. From there, ordinary JAX functions
work directly: jax.lax.dynamic_slice_in_dim, jnp.matmul, etc. NNX
modules are mostly transparent at the array layer: the wrapping
exists for graph traversal, not for compute.
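As a small illustration of working at the array layer, the sketch below takes one device's slice with jax.lax.dynamic_slice_in_dim and checks it against ordinary Python slicing. The arrays are stand-ins built with jnp.arange rather than real nnx.Linear parameters, and the sizes are arbitrary.

```python
import jax
import jax.numpy as jnp

# Stand-in arrays; in the problem these would come from model.kernel / model.bias.
kernel = jnp.arange(24.0).reshape(4, 6)  # (in_dim=4, out_dim=6)
bias = jnp.arange(6.0)                   # (out_dim=6,)

tp_size = 3
chunk = kernel.shape[1] // tp_size       # 2 columns per device
i = 1                                    # pretend we are device 1

# dynamic_slice_in_dim(operand, start_index, slice_size, axis)
k_lax = jax.lax.dynamic_slice_in_dim(kernel, i * chunk, chunk, axis=1)
b_lax = jax.lax.dynamic_slice_in_dim(bias, i * chunk, chunk, axis=0)

# Equivalent Python slicing, as used in the loop version.
k_py = kernel[:, i * chunk:(i + 1) * chunk]
b_py = bias[i * chunk:(i + 1) * chunk]
print(k_lax.shape, b_lax.shape)
```

The lax form takes a start index and a fixed size, which is convenient when the device index is a traced value inside jit; plain slicing is simpler when the index is a Python int.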
Common pitfalls
- Slicing along axis 0 instead of axis 1: that’s row-parallel, and it requires a different reduce (psum of partial outputs) to assemble correctly.
- Forgetting the bias split: bias is shape (out_dim,); shard along axis 0 so each device’s b_i matches its k_i.
- Concat along the wrong axis: each device’s output is (N, chunk); concat along axis=-1 (or axis=1), not along the batch axis.
- Asking for a tp_size that does not divide out_dim: this problem assumes even division. Production handles padding.
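For the uneven-division pitfall, one possible approach is to zero-pad the output axis up to the next multiple of tp_size and trim the result afterwards. The sketch below is only an assumption about how padding might be handled, not a description of any particular production system; the function name column_parallel_padded is hypothetical.

```python
import jax
import jax.numpy as jnp


def column_parallel_padded(x, kernel, bias, tp_size):
    """Simulated column-parallel Linear that tolerates out_dim % tp_size != 0."""
    out_dim = kernel.shape[1]
    pad = (-out_dim) % tp_size                      # extra columns needed
    kernel_p = jnp.pad(kernel, ((0, 0), (0, pad)))  # zero columns on the right
    bias_p = jnp.pad(bias, (0, pad))                # matching zero bias entries
    chunk = (out_dim + pad) // tp_size
    outs = []
    for i in range(tp_size):
        k = kernel_p[:, i * chunk:(i + 1) * chunk]
        b = bias_p[i * chunk:(i + 1) * chunk]
        outs.append(x @ k + b)
    # Drop the padded columns so the caller sees the original out_dim.
    return jnp.concatenate(outs, axis=-1)[:, :out_dim]


key_x, key_k, key_b = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (3, 8))
kernel = jax.random.normal(key_k, (8, 10))  # out_dim=10 is not divisible by 4
bias = jax.random.normal(key_b, (10,))
padded_out = column_parallel_padded(x, kernel, bias, tp_size=4)
```

The padded columns produce zeros that are discarded, so the trimmed result should agree with x @ kernel + bias.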
Problem
Implement tensor_parallel_linear(seed, x, features, tp_size):
- Cast seed, features, tp_size to ints. Use x.shape[-1] as the input dim.
- Build model = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed)).
- kernel = model.kernel.value; bias = model.bias.value.
- chunk = features // tp_size.
- Loop i in range(tp_size): take kernel[:, i*chunk:(i+1)*chunk] and bias[i*chunk:(i+1)*chunk]; compute x @ k + b; collect.
- out = jnp.concatenate(outs, axis=-1), with shape (N, features).
- Return out.reshape(-1), flattened.
Inputs:
- seed: float (cast to int).
- x: 2-D (N, in_dim).
- features: float (cast to int); Linear out dim, divisible by tp_size.
- tp_size: float (cast to int); model-parallel size.
Output: 1-D, length N * features.