hard primitives

NNX Tensor-Parallel Linear

Why this matters

Some layers are too big to fit on a single device: a Llama-70B Linear(8192, 28672) kernel holds about 235M parameters, which is roughly 470 MB at bf16 (940 MB at fp32), and the model contains many such matrices. Tensor parallelism (TP) carves a single layer’s weight matrix across devices, so each device computes a fraction of the output and the fractions are then concatenated.
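A quick check of the kernel-size arithmetic, as a minimal sketch. The helper name kernel_megabytes and the choice of MB = 10**6 bytes are assumptions made for illustration; at 2 bytes per parameter (bf16) the kernel comes to about 470 MB, and 940 MB corresponds to 4 bytes per parameter (fp32).

```python
# Kernel-size arithmetic for a Linear(8192, 28672) weight matrix.
in_dim, out_dim = 8192, 28672
n_params = in_dim * out_dim  # 234_881_024 parameters

BYTES_PER_PARAM = {"bf16": 2, "fp32": 4}


def kernel_megabytes(dtype: str, tp_size: int = 1) -> float:
    """Per-device kernel size in MB (10**6 bytes) when split over tp_size devices."""
    return n_params * BYTES_PER_PARAM[dtype] / tp_size / 1e6


print(f"bf16, unsharded: {kernel_megabytes('bf16'):.0f} MB")     # 470 MB
print(f"fp32, unsharded: {kernel_megabytes('fp32'):.0f} MB")     # 940 MB
print(f"bf16, tp_size=8: {kernel_megabytes('bf16', 8):.1f} MB")  # 58.7 MB
```

Splitting the kernel over tp_size devices divides the per-device footprint by tp_size, which is the memory argument for TP.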

Compare to data parallelism, which replicates the whole layer and splits the batch. TP and DP are orthogonal — production transformers typically use both: data parallel across one mesh axis, tensor parallel across another.
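To see why the two are orthogonal, here is a single-device simulation, offered as a sketch rather than a real multi-device setup. The sizes dp_size and tp_size and the random kernel are assumptions for illustration; each (d, t) pair plays the role of one device on a hypothetical (data, model) mesh.

```python
import jax
import jax.numpy as jnp

dp_size, tp_size = 2, 4        # hypothetical mesh shape: (data, model)
N, in_dim, out_dim = 6, 8, 16  # N divisible by dp_size, out_dim by tp_size

kx, kk, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (N, in_dim))
kernel = jax.random.normal(kk, (in_dim, out_dim))
bias = jax.random.normal(kb, (out_dim,))

x_shards = jnp.split(x, dp_size, axis=0)       # DP: split the batch rows
k_shards = jnp.split(kernel, tp_size, axis=1)  # TP: split the output columns
b_shards = jnp.split(bias, tp_size, axis=0)

# Simulated device (d, t) sees batch shard d and kernel/bias shard t.
blocks = [[xd @ kt + bt for kt, bt in zip(k_shards, b_shards)]
          for xd in x_shards]
out = jnp.block(blocks)                        # reassemble to (N, out_dim)

assert out.shape == (N, out_dim)
assert jnp.allclose(out, x @ kernel + bias, atol=1e-4)
```

The batch split and the column split never interact: every output block depends on exactly one batch shard and one kernel shard.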

Column-parallel Linear

The most common pattern: split the kernel along the output axis. Given kernel: (in_dim, out_dim) and tp_size model-parallel devices,

each device i holds: kernel[:, i*chunk:(i+1)*chunk]   # (in_dim, out_dim/tp)
                      bias[i*chunk:(i+1)*chunk]        # (out_dim/tp,)

Each device computes x @ kernel_i + bias_i to get its slice of the output (shape (N, out_dim/tp)). The full output is the concatenation:

out = concat([x @ kernel_0 + bias_0, ..., x @ kernel_{tp-1} + bias_{tp-1}], axis=-1)

The mesh-axis name is conventionally "model". The PartitionSpec is P(None, "model") for the kernel and P("model") for the bias.
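The PartitionSpecs above can be applied with jax.sharding, sketched here under the assumption that all visible devices form a 1-D "model" mesh. On a single-device machine the mesh has size 1 and the sharding is trivial, but the code path is the same; out_dim is chosen as a multiple of the device count so the split is even.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())  # 1-D device array; may hold one device
mesh = Mesh(devices, ("model",))
tp_size = devices.size

N, in_dim = 4, 8
out_dim = 16 * tp_size             # assumption: sized to divide evenly

kx, kk, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (N, in_dim))
kernel = jax.random.normal(kk, (in_dim, out_dim))
bias = jax.random.normal(kb, (out_dim,))

# Kernel columns and bias entries are split over "model"; x is replicated.
kernel_sharded = jax.device_put(kernel, NamedSharding(mesh, P(None, "model")))
bias_sharded = jax.device_put(bias, NamedSharding(mesh, P("model")))
x_replicated = jax.device_put(x, NamedSharding(mesh, P()))


@jax.jit
def linear(x, kernel, bias):
    return x @ kernel + bias


out = linear(x_replicated, kernel_sharded, bias_sharded)
expected = np.asarray(x @ kernel + bias)
assert out.shape == (N, out_dim)
assert np.allclose(np.asarray(out), expected, atol=1e-4)
```

Under jit, the compiler propagates the input shardings, so the matmul itself needs no explicit slicing.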

Why x doesn’t shard

x: (N, in_dim) is replicated across the model-parallel devices (it’s already broadcast to all of them). What’s sharded is the kernel, not the input. After the matmul, each device has a different slice of the output — concat reconstructs the full (N, out_dim).

Contrast with row-parallel Linear (P("model", None)): the kernel splits along in_dim, the input is split along the same axis, each device produces a partial sum of the output, and a final psum all-reduce assembles the full result.
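For comparison, a single-device sketch of the row-parallel case. The sizes and the random kernel are assumptions for illustration, and the plain sum over partial products stands in for the psum all-reduce that real devices would perform; the bias is added once, after the reduction.

```python
import jax
import jax.numpy as jnp

tp_size = 4
N, in_dim, out_dim = 3, 8, 6   # in_dim divisible by tp_size

kx, kk, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (N, in_dim))
kernel = jax.random.normal(kk, (in_dim, out_dim))
bias = jax.random.normal(kb, (out_dim,))

x_shards = jnp.split(x, tp_size, axis=1)       # each (N, in_dim/tp)
k_shards = jnp.split(kernel, tp_size, axis=0)  # each (in_dim/tp, out_dim)

# Each simulated device produces a full-shape partial sum of the output.
partials = [xi @ ki for xi, ki in zip(x_shards, k_shards)]

# Summing the partials plays the role of psum; bias is added a single time.
out = jnp.sum(jnp.stack(partials), axis=0) + bias

assert out.shape == (N, out_dim)
assert jnp.allclose(out, x @ kernel + bias, atol=1e-4)
```

Note the difference from the column-parallel case: every partial already has shape (N, out_dim), so the results are added, not concatenated.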

Why build the model normally first

A real TP setup constructs the kernel directly in the sharded layout, so only the local slice is materialized on each device. In our single-device sandbox we build a normal nnx.Linear(in_dim, features), then manually slice its kernel in a loop. The full kernel exists in memory; we ignore that and use only the slices, to mimic what each device sees in production.

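import jax.numpy as jnp
from flax import nnx

# x, seed, features, tp_size are the problem inputs; in_dim = x.shape[-1]
model = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed))
kernel = model.kernel.value      # (in, out)
bias = model.bias.value          # (out,)
chunk = features // tp_size      # output columns per simulated device
outs = []
for i in range(tp_size):
    k = kernel[:, i*chunk:(i+1)*chunk]   # (in, chunk)
    b = bias[i*chunk:(i+1)*chunk]        # (chunk,)
    outs.append(x @ k + b)               # (N, chunk)
out = jnp.concatenate(outs, axis=-1)     # (N, out)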

The result is identical to model(x) — that’s the correctness invariant of column-parallel TP: concat([x @ K_i + b_i]) == x @ K + b.

Why call into raw arrays

model.kernel.value and model.bias.value unwrap the nnx.Param Variables to plain JAX arrays. From there, array functions work directly: jax.lax.dynamic_slice_in_dim, jnp.matmul, etc. NNX modules are mostly transparent at the array layer; the wrapping exists for graph traversal, not for compute.
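As an illustration of working on the raw array, the sketch below takes one shard with jax.lax.dynamic_slice_in_dim. The random kernel is an assumed stand-in for model.kernel.value, and the helper name shard is hypothetical.

```python
import jax
import jax.numpy as jnp

tp_size, in_dim, features = 4, 8, 16
chunk = features // tp_size

# Stand-in for model.kernel.value: a plain (in_dim, features) array.
kernel = jax.random.normal(jax.random.PRNGKey(0), (in_dim, features))

i = 2
k_static = kernel[:, i * chunk:(i + 1) * chunk]
k_dynamic = jax.lax.dynamic_slice_in_dim(kernel, i * chunk, chunk, axis=1)
assert k_dynamic.shape == (in_dim, chunk)
assert jnp.array_equal(k_static, k_dynamic)


# Under jit the shard index may be a traced value; only the slice size
# (chunk) has to be static.
@jax.jit
def shard(kernel, i):
    return jax.lax.dynamic_slice_in_dim(kernel, i * chunk, chunk, axis=1)


assert jnp.array_equal(shard(kernel, 3), kernel[:, 3 * chunk:])
```

Python slicing is enough for the loop in this problem; the dynamic variant matters when the shard index is only known inside a jitted function.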

Common pitfalls

  • Slicing along axis 0 instead of axis 1: that’s row-parallel, and it requires a different reduce (psum of partial outputs) to assemble correctly.
  • Forgetting the bias split: bias is shape (out_dim,) — shard along axis 0 so each device’s b_i matches its k_i.
  • Concat along the wrong axis: each device’s output is (N, chunk) — concat along axis=-1 (or axis=1), not along the batch axis.
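  • A tp_size that does not divide out_dim: this problem assumes even division. With chunk = features // tp_size, the last features % tp_size output columns would be silently dropped. Production code typically pads the kernel or rejects such shapes.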
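Two of these pitfalls are easy to reproduce on toy shapes. The sketch below uses assumed sizes (features = 10, tp_size = 4) and shows the silent truncation, the wrong-axis concat, and one possible zero-padding workaround; the padding scheme is an illustration, not a description of how any particular production system handles it.

```python
import jax.numpy as jnp

N, in_dim, features, tp_size = 2, 3, 10, 4  # 10 is not divisible by 4
x = jnp.ones((N, in_dim))
kernel = jnp.ones((in_dim, features))
bias = jnp.zeros((features,))

# Pitfall: floor division covers only tp_size * chunk = 8 of 10 columns.
chunk = features // tp_size
outs = [x @ kernel[:, i * chunk:(i + 1) * chunk] + bias[i * chunk:(i + 1) * chunk]
        for i in range(tp_size)]
truncated = jnp.concatenate(outs, axis=-1)
assert truncated.shape == (N, 8)            # two output columns are missing

# Pitfall: concatenating along the batch axis stacks the slices as rows.
wrong = jnp.concatenate(outs, axis=0)
assert wrong.shape == (N * tp_size, chunk)

# One possible fix: pad to a multiple of tp_size, then trim the result.
chunk_p = -(-features // tp_size)           # ceil division
pad = chunk_p * tp_size - features          # number of zero columns added
kernel_p = jnp.pad(kernel, ((0, 0), (0, pad)))
bias_p = jnp.pad(bias, (0, pad))
outs_p = [x @ kernel_p[:, i * chunk_p:(i + 1) * chunk_p]
          + bias_p[i * chunk_p:(i + 1) * chunk_p] for i in range(tp_size)]
full = jnp.concatenate(outs_p, axis=-1)[:, :features]
assert jnp.allclose(full, x @ kernel + bias)
```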

Problem

Implement tensor_parallel_linear(seed, x, features, tp_size):

  1. Cast seed, features, tp_size to ints. Use x.shape[-1] as the input dim.
  2. Build model = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed)).
  3. kernel = model.kernel.value; bias = model.bias.value.
  4. chunk = features // tp_size.
  5. Loop i in range(tp_size): take kernel[:, i*chunk:(i+1)*chunk] and bias[i*chunk:(i+1)*chunk]; compute x @ k + b; collect.
  6. out = jnp.concatenate(outs, axis=-1) — shape (N, features).
  7. Return out.reshape(-1), i.e. out flattened to 1-D.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, in_dim).
  • features: float (cast to int) — Linear out dim, divisible by tp_size.
  • tp_size: float (cast to int) — model-parallel size.

Output: 1-D, length N * features.

Hints

flax nnx sharding tensor-parallel linear
