
Mixture-of-Experts FFN

Why this matters

A standard Transformer block runs every token through one (D → 4D → D) FFN. Doubling the FFN's hidden size doubles its FLOPs PER TOKEN, so parameter count and compute scale together.

Mixture-of-Experts (MoE) is the trick that breaks this lock. Instead of one giant FFN, you have N smaller “expert” FFNs. For each token, a tiny router picks just k of them (typically k=1 or k=2). The token goes through ONLY those k experts, and their outputs are weighted by the router’s score.

Result: total parameters grow with N, but each token pays only for k experts, roughly k/N of the expert FLOPs that running all N experts would cost. Switch Transformer (1.6T params, k=1, 2021), Mixtral 8x7B (47B total / 13B active per token, k=2, 2023), and Google’s GLaM (k=2) all use this pattern.
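To make the k/N claim concrete, here is a back-of-envelope count for one MoE FFN layer in plain Python. The layer sizes are hypothetical, and the FLOP estimate uses the common convention of about 2 FLOPs per weight per token (one multiply, one add), ignoring biases, relu, and softmax.

```python
# Back-of-envelope sizes for one MoE FFN layer; all dimensions are hypothetical.

def ffn_params(d_model: int, d_ff: int) -> int:
    """Weights + biases of one expert: Dense(d_ff) -> relu -> Dense(d_model)."""
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)


def moe_layer_stats(d_model: int, d_ff: int, num_experts: int, k: int) -> dict:
    per_expert = ffn_params(d_model, d_ff)
    # A matmul with an (a, b) weight costs about 2*a*b FLOPs per token.
    flops_per_expert = 2 * (d_model * d_ff + d_ff * d_model)
    router_flops = 2 * d_model * num_experts  # Dense(num_experts), bias ignored
    return {
        "total_expert_params": num_experts * per_expert,
        "active_expert_params": k * per_expert,  # what one token actually touches
        "flops_per_token": k * flops_per_expert + router_flops,
        "flops_if_all_experts_ran": num_experts * flops_per_expert + router_flops,
    }


stats = moe_layer_stats(d_model=1024, d_ff=4096, num_experts=64, k=2)
for name, value in stats.items():
    print(f"{name:>26}: {value:,}")
```

Raising num_experts grows total_expert_params linearly, while flops_per_token stays fixed for a fixed k (only the small router term grows).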

The pieces

Given input x of shape (T, D):

  1. Experts: N parallel FFNs, each Dense(d_ff) → relu → Dense(D).
  2. Router: a single Dense(num_experts) followed by softmax, producing per-token expert probabilities (T, N).
  3. Top-k selection: pick the k experts with highest probs per token. Renormalize their probs to sum to 1.
  4. Sparse compute (logically): each token runs through only those k experts, weighted by renormalized probs.
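The four pieces can be sketched in pure JAX without Flax, assuming the expert and router weights are stored as explicit stacked arrays. The helpers init_params, route, expert, moe_sparse_loop, and moe_dense_gather are hypothetical names for illustration. moe_sparse_loop performs step 4 literally (each token visits only its k experts); moe_dense_gather runs every expert and then gathers, and the two should agree.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_model, d_ff, num_experts):
    """Hypothetical helper: random stacked weights for N experts plus a router."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "w1": jax.random.normal(k1, (num_experts, d_model, d_ff)) * 0.1,
        "b1": jnp.zeros((num_experts, d_ff)),
        "w2": jax.random.normal(k2, (num_experts, d_ff, d_model)) * 0.1,
        "b2": jnp.zeros((num_experts, d_model)),
        "wr": jax.random.normal(k3, (d_model, num_experts)) * 0.1,
        "br": jnp.zeros((num_experts,)),
    }


def route(params, x, k):
    """Steps 2-3: softmax router, top-k, renormalize. Returns (T, k) weights and indices."""
    probs = jax.nn.softmax(x @ params["wr"] + params["br"], axis=-1)  # (T, N)
    vals, idx = jax.lax.top_k(probs, k)                                # (T, k) each
    return vals / vals.sum(axis=-1, keepdims=True), idx


def expert(params, e, h):
    """Step 1 for a single expert e and a single token h of shape (D,)."""
    hidden = jax.nn.relu(h @ params["w1"][e] + params["b1"][e])  # (d_ff,)
    return hidden @ params["w2"][e] + params["b2"][e]            # (D,)


def moe_sparse_loop(params, x, k):
    """Step 4 literally: each token runs through only its k selected experts."""
    weights, idx = route(params, x, k)
    rows = []
    for t in range(x.shape[0]):
        out_t = jnp.zeros(x.shape[-1])
        for j in range(k):
            out_t = out_t + weights[t, j] * expert(params, int(idx[t, j]), x[t])
        rows.append(out_t)
    return jnp.stack(rows)  # (T, D)


def moe_dense_gather(params, x, k):
    """Same result, vectorized: every expert on every token, then gather the top-k."""
    weights, idx = route(params, x, k)
    hidden = jax.nn.relu(jnp.einsum("td,ndf->tnf", x, params["w1"]) + params["b1"])
    expert_outs = jnp.einsum("tnf,nfd->tnd", hidden, params["w2"]) + params["b2"]  # (T, N, D)
    gathered = jnp.take_along_axis(expert_outs, idx[:, :, None], axis=1)          # (T, k, D)
    return jnp.sum(gathered * weights[:, :, None], axis=1)                       # (T, D)


params = init_params(jax.random.PRNGKey(0), d_model=4, d_ff=8, num_experts=4)
x = jax.random.normal(jax.random.PRNGKey(1), (3, 4))
sparse_out = moe_sparse_loop(params, x, k=2)
dense_out = moe_dense_gather(params, x, k=2)
```

The loop version is slow and only meant as a readable reference; the vectorized version trades extra compute for a single batched computation.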

Top-k in JAX

jax.lax.top_k(operand, k) returns the k largest entries along the last axis (largest first) together with their indices, as (top_vals, top_idx), both of shape (..., k):

import jax
router_probs = jax.nn.softmax(router_logits, axis=-1)       # router_logits: (T, N) -> probs (T, N)
topk_vals, topk_idx = jax.lax.top_k(router_probs, k)        # (T, k), (T, k)
topk_w = topk_vals / topk_vals.sum(axis=-1, keepdims=True)  # renormalize so each row sums to 1

The renormalization is important: after picking k experts, their probs no longer sum to 1 (the other N − k experts hold the remaining mass). Renormalize so the output is still a convex combination of the selected experts’ outputs.
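A quick numerical check of the point above, using random stand-in logits: before renormalization the kept mass is below 1. It also shows a useful identity: because softmax preserves order, top-k of the softmax followed by renormalization gives the same weights as a softmax over just the top-k logits.

```python
import jax
import jax.numpy as jnp

logits = jax.random.normal(jax.random.PRNGKey(0), (5, 8))  # stand-in router logits (T=5, N=8)
k = 2

# Path A: softmax over all N experts, keep the top-k, renormalize.
probs = jax.nn.softmax(logits, axis=-1)
topk_vals, topk_idx = jax.lax.top_k(probs, k)
topk_w = topk_vals / topk_vals.sum(axis=-1, keepdims=True)

# Mass kept before renormalization: strictly below 1 per token.
kept_mass = topk_vals.sum(axis=-1)  # (T,)

# Path B: softmax over only the k largest logits. Softmax is monotonic,
# so top-k of logits selects the same experts as top-k of probs.
topk_logits, topk_idx_b = jax.lax.top_k(logits, k)
topk_w_b = jax.nn.softmax(topk_logits, axis=-1)
```

Path B skips computing the full softmax denominator, which is why some implementations apply softmax after top-k instead.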

Sparse vs dense compute

Production MoE implementations dispatch tokens so that each expert computes only its assigned subset (true sparsity), often with a fixed per-expert capacity. For a short tutorial, that’s overkill; instead, COMPUTE every expert on every token, then GATHER the top-k:

# Inside an @nn.compact __call__: each Expert(d_ff) is a separate instance with its own params.
expert_outs = jnp.stack([Expert(d_ff)(x) for _ in range(N)], axis=1)
# shape: (T, N, D), every token through every expert

Then index expert_outs with topk_idx:

# topk_idx[:, :, None] has shape (T, k, 1) and broadcasts across D.
gathered = jnp.take_along_axis(expert_outs, topk_idx[:, :, None], axis=1)
# shape: (T, k, D)
out = jnp.sum(gathered * topk_w[:, :, None], axis=1)
# shape: (T, D)

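This is mathematically equivalent to true sparse dispatch (provided no tokens are dropped for exceeding an expert’s capacity) and fits cleanly in @nn.compact. The “wasted” compute on the N − k unselected experts, N/k times the expert FLOPs actually needed, is fine at tutorial scale.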

Worked walk-through

T=2, D=4, N=4, d_ff=8, k=2:

  1. Build 4 expert FFNs, run all of them on x, and stack: (2, 4, 4), i.e. (T, N, D).
  2. Router: Dense(4)(x) → softmax: (2, 4).
  3. top_k(probs, 2) → vals (2, 2), idx (2, 2).
  4. Renormalize vals along the last axis.
  5. Gather expert outputs by idx along the expert axis: (2, 2, 4).
  6. Multiply by the weights reshaped to (2, 2, 1), then sum over the expert axis: (2, 4).
  7. Flatten: (8,).
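The walk-through can be traced with the same sizes (T=2, D=4, N=4, k=2). To keep the focus on shapes, random arrays stand in for the expert outputs and router logits, so d_ff does not appear.

```python
import jax
import jax.numpy as jnp

T, D, N, k = 2, 4, 4, 2
key_experts, key_router = jax.random.split(jax.random.PRNGKey(0))

expert_outs = jax.random.normal(key_experts, (T, N, D))                # step 1 stand-in: (2, 4, 4)
probs = jax.nn.softmax(jax.random.normal(key_router, (T, N)), axis=-1)  # step 2: (2, 4)
vals, idx = jax.lax.top_k(probs, k)                                     # step 3: (2, 2), (2, 2)
w = vals / vals.sum(axis=-1, keepdims=True)                             # step 4: (2, 2)
gathered = jnp.take_along_axis(expert_outs, idx[:, :, None], axis=1)    # step 5: (2, 2, 4)
combined = jnp.sum(gathered * w[:, :, None], axis=1)                    # step 6: (2, 4)
flat = combined.reshape(-1)                                             # step 7: (8,)
```

Each gathered[t, j] is exactly expert_outs[t, idx[t, j]], i.e. the j-th selected expert’s output for token t.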

Common pitfalls

  • Forgetting to renormalize the top-k probs: the selected weights sum to less than 1, scaling the output down. Fix: topk_w = topk_vals / topk_vals.sum(axis=-1, keepdims=True).
  • Sharing one Expert across all routes: each expert needs INDEPENDENT params. Build N separate Expert instances inside the @nn.compact body (a list comprehension is fine).
  • Wrong axis on take_along_axis: experts are stacked on axis 1 (after T), so gather along axis 1.
  • Skipping the softmax: raw logits can be negative, so gating with them doesn’t give probabilities and the weighted combination is no longer convex.
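The pitfalls above can be reproduced numerically. This sketch uses a hypothetical combine helper over precomputed expert outputs: with one shared expert the routing has no effect, skipping renormalization shrinks each row by its kept mass, and renormalizing raw (un-softmaxed) scores can yield a negative weight. The top-2 raw logits in the last part are made-up values.

```python
import jax
import jax.numpy as jnp


def combine(expert_outs, router_logits, k, renormalize=True):
    """Hypothetical helper: top-k combine of expert outputs (T, N, D) -> (T, D)."""
    probs = jax.nn.softmax(router_logits, axis=-1)                        # (T, N)
    vals, idx = jax.lax.top_k(probs, k)                                   # (T, k)
    w = vals / vals.sum(axis=-1, keepdims=True) if renormalize else vals
    gathered = jnp.take_along_axis(expert_outs, idx[:, :, None], axis=1)  # (T, k, D)
    return jnp.sum(gathered * w[:, :, None], axis=1)


T, N, D, k = 3, 4, 5, 2
key_logits, key_out = jax.random.split(jax.random.PRNGKey(0))
logits = jax.random.normal(key_logits, (T, N))

# Shared expert: all N "experts" give the same output, so with renormalized
# weights the router changes nothing and the layer is just one dense FFN.
single = jax.random.normal(key_out, (T, D))
shared_outs = jnp.repeat(single[:, None, :], N, axis=1)  # (T, N, D)
collapsed = combine(shared_outs, logits, k)               # equals `single`

# No renormalization: each row is scaled by its kept probability mass (< 1).
shrunk = combine(shared_outs, logits, k, renormalize=False)
kept_mass = jax.lax.top_k(jax.nn.softmax(logits, axis=-1), k)[0].sum(axis=-1)  # (T,)

# No softmax: renormalizing raw scores can produce a negative weight.
raw_top = jnp.array([[0.5, -0.2]])                      # made-up top-2 raw logits
raw_w = raw_top / raw_top.sum(axis=-1, keepdims=True)   # approximately [[1.667, -0.667]]
```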

Problem

Implement moe_ffn_forward(seed, x, num_experts, d_ff, top_k):

  1. Expert(nn.Module) with field d_ff: Dense(d_ff) → relu → Dense(D) where D = x.shape[-1].
  2. MoEFFN(nn.Module) with num_experts, d_ff, top_k fields.
  3. Inside: build num_experts Experts, run all on x, stack on a new expert axis.
  4. Router: Dense(num_experts) + softmax. Top-k + renormalize.
  5. Gather expert outputs by top-k idx, weight, sum.
  6. Return the (T, D) output flattened to 1-D.

Inputs:

  • seed: int.
  • x: 2-D (T, D).
  • num_experts: int N.
  • d_ff: int (FFN hidden dim).
  • top_k: int (experts per token).

Output: 1-D, length T * D.

Hints

flax moe router top-k
