Mixture-of-Experts FFN
Why this matters
A standard Transformer block runs every token through one dense
(D → 4D → D) FFN. Doubling the FFN size doubles the FLOPs PER TOKEN,
so model size and compute scale together.
Mixture-of-Experts (MoE) is the trick that breaks this lock.
Instead of one giant FFN, you have N smaller “expert” FFNs.
For each token, a tiny router picks just k of them
(typically k=1 or k=2). The token goes through ONLY those
k experts, and their outputs are weighted by the router’s
score.
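Result: total parameters grow with N, while each token's FFN compute
grows only with k. A model can hold far more parameters than a dense
model of the same per-token cost, spending only k/N of the FLOPs it
would take to run every expert. Switch Transformer (1.6T params, 2021),
Mixtral (47B total / 13B active per token, 2023), and Google’s GLaM all
use this pattern.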
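To make the parameter/compute trade-off concrete, here is back-of-the-envelope arithmetic under assumed, hypothetical sizes (D = 1024, d_ff = 4D = 4096, N = 64, k = 2; not any real model's configuration). It assumes every expert has the Dense(d_ff) → relu → Dense(D) shape used in this tutorial and counts roughly 2 FLOPs per multiply-accumulate.

```python
# Rough parameter / FLOP counts for one FFN layer: dense vs. MoE.
# Assumptions: experts are Dense(d_ff) -> relu -> Dense(D) with biases;
# an (in, out) matmul costs about 2 * in * out FLOPs per token.

def ffn_params(d_model: int, d_ff: int) -> int:
    # W1 (D x d_ff) + b1 (d_ff) + W2 (d_ff x D) + b2 (D)
    return d_model * d_ff + d_ff + d_ff * d_model + d_model

def ffn_flops_per_token(d_model: int, d_ff: int) -> int:
    # Two matmuls; bias adds and relu are ignored as negligible.
    return 2 * d_model * d_ff + 2 * d_ff * d_model

def moe_counts(d_model: int, d_ff: int, num_experts: int, k: int):
    per_expert = ffn_params(d_model, d_ff)
    router = d_model * num_experts + num_experts      # Dense(num_experts)
    total = num_experts * per_expert + router         # all stored params
    active = k * per_expert + router                  # params a token touches
    flops = k * ffn_flops_per_token(d_model, d_ff) + 2 * d_model * num_experts
    return total, active, flops

D, D_FF = 1024, 4096
dense_params = ffn_params(D, D_FF)
dense_flops = ffn_flops_per_token(D, D_FF)
total, active, moe_flops = moe_counts(D, D_FF, num_experts=64, k=2)

print(f"dense FFN: {dense_params:,} params, {dense_flops:,} FLOPs/token")
print(f"MoE N=64, k=2: {total:,} params ({active:,} active), {moe_flops:,} FLOPs/token")
```

With these made-up sizes the MoE layer stores roughly 64x the dense FFN's parameters, while per-token FLOPs only about double (two experts plus a small router).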
The pieces
Given input x of shape (T, D):
- Experts: N parallel FFNs, each Dense(d_ff) → relu → Dense(D).
- Router: a single Dense(num_experts) followed by softmax, producing per-token expert probabilities of shape (T, N).
- Top-k selection: pick the k experts with the highest probabilities per token, then renormalize their probabilities to sum to 1.
- Sparse compute (logically): each token runs through only those k experts, and their outputs are combined using the renormalized probabilities as weights.
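The pieces can be sketched in plain JAX, with explicit weight arrays standing in for the Flax Dense layers. The names w1/b1/w2/b2/wr/br and all sizes are illustrative assumptions; the per-token loop is a deliberately slow reference that follows the "logically sparse" description literally.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes: T tokens, model dim D, N experts, hidden dim D_FF, top-K.
T, D, N, D_FF, K = 3, 4, 5, 8, 2
key_x, key_1, key_2, key_r = jax.random.split(jax.random.PRNGKey(0), 4)

x = jax.random.normal(key_x, (T, D))
w1 = jax.random.normal(key_1, (N, D, D_FF)) * 0.1  # one Dense(d_ff) kernel per expert
b1 = jnp.zeros((N, D_FF))
w2 = jax.random.normal(key_2, (N, D_FF, D)) * 0.1  # one Dense(D) kernel per expert
b2 = jnp.zeros((N, D))
wr = jax.random.normal(key_r, (D, N)) * 0.1        # router Dense(num_experts) kernel
br = jnp.zeros((N,))

def expert(e, xs):
    """Expert e: Dense(d_ff) -> relu -> Dense(D), mapping (t, D) -> (t, D)."""
    return jax.nn.relu(xs @ w1[e] + b1[e]) @ w2[e] + b2[e]

# Router: Dense(N) + softmax -> per-token expert probabilities, shape (T, N).
probs = jax.nn.softmax(x @ wr + br, axis=-1)

# Top-k selection, then renormalize the chosen probabilities to sum to 1.
topk_vals, topk_idx = jax.lax.top_k(probs, K)
topk_w = topk_vals / topk_vals.sum(axis=-1, keepdims=True)

# Logically sparse compute: token t visits only its K chosen experts.
out = jnp.stack([
    sum(topk_w[t, j] * expert(int(topk_idx[t, j]), x[t:t + 1])[0] for j in range(K))
    for t in range(T)
])
print(probs.shape, topk_idx.shape, out.shape)
```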
Top-k in JAX
jax.lax.top_k(values, k) returns (top_vals, top_idx), both
shape (..., k):
router_probs = jax.nn.softmax(router_logits, axis=-1) # (T, N)
topk_vals, topk_idx = jax.lax.top_k(router_probs, k) # (T, k), (T, k)
topk_w = topk_vals / topk_vals.sum(axis=-1, keepdims=True) # renorm
The renormalization is important: after picking k experts, their
probs generally sum to less than 1. Dividing by their sum keeps the
weighted combination a proper convex combination (non-negative
weights that sum to 1).
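A small numeric check of top-k plus renormalization, using made-up router probabilities (not output from any trained model):

```python
import jax
import jax.numpy as jnp

# Toy router probabilities for T = 2 tokens over N = 4 experts.
probs = jnp.array([[0.50, 0.30, 0.15, 0.05],
                   [0.10, 0.20, 0.30, 0.40]])

topk_vals, topk_idx = jax.lax.top_k(probs, 2)  # sorted, largest first
print(topk_idx)            # token 0 -> experts 0, 1; token 1 -> experts 3, 2
print(topk_vals.sum(-1))   # about 0.8 and 0.7: no longer sums to 1

topk_w = topk_vals / topk_vals.sum(axis=-1, keepdims=True)
print(topk_w)              # token 0: 0.625, 0.375; token 1: about 0.571, 0.429
```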
Sparse vs dense compute
Production MoE implementations dispatch tokens so each expert computes only its assigned subset (true sparsity). For a short tutorial that's overkill; instead, COMPUTE every expert on every token, then GATHER the top-k:
expert_outs = jnp.stack([Expert(d_ff)(x) for _ in range(N)], axis=1)
# shape: (T, N, D) — every token through every expert
Then index expert_outs with topk_idx:
gathered = jnp.take_along_axis(expert_outs, topk_idx[:, :, None], axis=1)
# shape: (T, k, D)
out = jnp.sum(gathered * topk_w[:, :, None], axis=1)
# shape: (T, D)
This gives the same output as the sparse version and fits cleanly in
a single @nn.compact method. The “wasted” compute on non-selected
experts (N/k times the expert FLOPs actually needed) is fine at
tutorial scale.
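One way to check the equivalence claim is to compare the dense "compute all, then gather" path against a per-token loop that calls only the selected experts. This sketch assumes bias-free experts and random weights, and uses jnp.einsum to run all N experts at once in place of the list comprehension; both build the same (T, N, D) stack.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes and random bias-free weights for N experts + router.
T, D, N, D_FF, K = 5, 4, 6, 8, 2
key_x, key_1, key_2, key_r = jax.random.split(jax.random.PRNGKey(42), 4)
x = jax.random.normal(key_x, (T, D))
w1 = jax.random.normal(key_1, (N, D, D_FF)) * 0.5
w2 = jax.random.normal(key_2, (N, D_FF, D)) * 0.5
wr = jax.random.normal(key_r, (D, N))

probs = jax.nn.softmax(x @ wr, axis=-1)                        # (T, N)
topk_vals, topk_idx = jax.lax.top_k(probs, K)                  # (T, K), (T, K)
topk_w = topk_vals / topk_vals.sum(axis=-1, keepdims=True)

# Dense path: every token through every expert, experts on axis 1.
h = jax.nn.relu(jnp.einsum("td,ndf->tnf", x, w1))              # (T, N, d_ff)
expert_outs = jnp.einsum("tnf,nfd->tnd", h, w2)                # (T, N, D)
gathered = jnp.take_along_axis(expert_outs, topk_idx[:, :, None], axis=1)  # (T, K, D)
dense_out = jnp.sum(gathered * topk_w[:, :, None], axis=1)     # (T, D)

# Sparse reference: token t only runs its K selected experts.
def run_expert(e, xt):
    return jax.nn.relu(xt @ w1[e]) @ w2[e]

sparse_out = jnp.stack([
    sum(topk_w[t, j] * run_expert(int(topk_idx[t, j]), x[t]) for j in range(K))
    for t in range(T)
])
print(jnp.allclose(dense_out, sparse_out, rtol=1e-4, atol=1e-4))
```

Only the order of operations differs between the two paths, so any mismatch should be at floating-point rounding level.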
Worked walk-through
T=2, D=4, N=4, d_ff=8, k=2:
- Build 4 expert FFNs and run all of them on x: (2, 4, 4), i.e. (T, N, D).
- Router: Dense(4)(x) → softmax → (2, 4).
- top_k(probs, 2) → vals (2, 2), idx (2, 2). Renormalize vals along the last axis.
- Gather expert outputs by idx: (2, 2, 4).
- Multiply by weights (2, 2, 1) and sum over the expert axis: (2, 4).
- Flatten: (8,).
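The same walk-through as runnable shape checks, with random plain-JAX weights as hypothetical stand-ins for the Flax modules (bias-free for brevity):

```python
import jax
import jax.numpy as jnp

# Walk-through sizes: T=2, D=4, N=4, d_ff=8, k=2.
T, D, N, D_FF, K = 2, 4, 4, 8, 2
keys = jax.random.split(jax.random.PRNGKey(0), 2 * N + 2)
x = jax.random.normal(keys[0], (T, D))

# 1. N independent experts, each with its own (D, d_ff) and (d_ff, D) kernels.
experts = [(jax.random.normal(keys[1 + 2 * e], (D, D_FF)),
            jax.random.normal(keys[2 + 2 * e], (D_FF, D))) for e in range(N)]
expert_outs = jnp.stack([jax.nn.relu(x @ w1) @ w2 for w1, w2 in experts], axis=1)
assert expert_outs.shape == (T, N, D)                 # (2, 4, 4)

# 2. Router: Dense(N) -> softmax.
probs = jax.nn.softmax(x @ jax.random.normal(keys[-1], (D, N)), axis=-1)
assert probs.shape == (T, N)                          # (2, 4)

# 3. Top-k, then renormalize along the last axis.
topk_vals, topk_idx = jax.lax.top_k(probs, K)
assert topk_vals.shape == topk_idx.shape == (T, K)    # (2, 2)
topk_w = topk_vals / topk_vals.sum(axis=-1, keepdims=True)

# 4. Gather the chosen experts' outputs along the expert axis.
gathered = jnp.take_along_axis(expert_outs, topk_idx[:, :, None], axis=1)
assert gathered.shape == (T, K, D)                    # (2, 2, 4)

# 5. Weight by (2, 2, 1) and sum over the expert axis.
out = jnp.sum(gathered * topk_w[:, :, None], axis=1)
assert out.shape == (T, D)                            # (2, 4)

# 6. Flatten.
flat = out.reshape(-1)
assert flat.shape == (T * D,)                         # (8,)
```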
Common pitfalls
- Forgetting to renormalize the top-k probs: the selected probabilities sum to less than 1, which scales the output down. Fix: topk_w = topk_vals / topk_vals.sum(axis=-1, keepdims=True).
- Sharing one Expert across all routes: each expert needs INDEPENDENT params. Build N separate Expert instances inside the @nn.compact body (a list comprehension is fine).
- Wrong axis on take_along_axis: experts are stacked on axis 1 (after T), so gather along axis 1.
- Skipping the softmax: gating with raw logits doesn't produce probabilities, so the weighted combination is no longer convex (weights can be negative or larger than 1).
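The last pitfall is easy to see with made-up logits for a single token: renormalizing raw logits can yield negative weights or weights above 1, while applying softmax first keeps them in [0, 1].

```python
import jax
import jax.numpy as jnp

# Made-up router logits for one token over N = 4 experts.
logits = jnp.array([[2.0, -1.5, -3.0, -4.0]])

# Pitfall: top-k + renormalize directly on raw logits.
vals, _ = jax.lax.top_k(logits, 2)                  # [[2.0, -1.5]]
bad_w = vals / vals.sum(axis=-1, keepdims=True)     # [[4.0, -3.0]]: not convex

# Fix: softmax first, then top-k + renormalize.
vals, _ = jax.lax.top_k(jax.nn.softmax(logits, axis=-1), 2)
good_w = vals / vals.sum(axis=-1, keepdims=True)    # both in [0, 1], sum to 1
print(bad_w, good_w)
```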
Problem
Implement moe_ffn_forward(seed, x, num_experts, d_ff, top_k):
- Expert(nn.Module) with field d_ff: Dense(d_ff) → relu → Dense(D), where D = x.shape[-1].
- MoEFFN(nn.Module) with fields num_experts, d_ff, top_k.
- Inside MoEFFN: build num_experts Experts, run all of them on x, and stack the outputs on a new expert axis.
- Router: Dense(num_experts) + softmax, then top-k + renormalize.
- Gather expert outputs by top-k idx, weight, and sum.
- Return the result flattened.
Inputs:
- seed: int.
- x: 2-D array of shape (T, D).
- num_experts: int N.
- d_ff: int (FFN hidden dim).
- top_k: int (experts per token).
Output: 1-D, length T * D.