NNX SwiGLU FFN
Why this matters
In the original Transformer (and the encoder/decoder blocks in
pos 41/42), the position-wise FFN is Dense -> ReLU -> Dense. Modern
LLMs — LLaMA, PaLM, Mistral, DeepSeek — replace this with SwiGLU:
a GATED feed-forward block where one branch modulates the other
via element-wise multiplication.
At the same d_ff, SwiGLU costs ~50% more parameters and FLOPs than the vanilla FFN (three weight matrices instead of two), but it consistently improves perplexity. Once you know how it works, swapping it into a transformer block is one line of code, which is why it’s everywhere now.
The formula
For input x of shape (T, d_model):
SwiGLU(x) = down( silu(gate(x)) * up(x) )
where:
- gate, up: each nnx.Linear(d_model, d_ff). Two separate projections to the wider hidden width.
- silu: a.k.a. the swish nonlinearity, silu(z) = z * sigmoid(z). Smooth, non-monotonic, easy to differentiate.
- *: element-wise multiplication along all axes (this is the “GLU gate”; the gated unit).
- down: nnx.Linear(d_ff, d_model). Project back to d_model.
Three linears total (vs two in the vanilla FFN). The “up” branch
carries information; the “silu(gate)” branch decides how much of
each feature to let through. Multiplicative gating gives the network
a way to express “this dimension is irrelevant for this token” by
setting the gate to ~0, no matter what up(x) says.
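The formula can be sketched in plain JAX with explicit weight matrices. This is a minimal illustration, not the nnx version used later: it assumes bias-free projections, and the names `swiglu`, `w_gate`, `w_up`, `w_down` are made up for this example. The last lines check the gating claim: when the gate pre-activation is zero, silu(0) = 0 shuts every hidden feature regardless of up(x).

```python
import jax
import jax.numpy as jnp


def swiglu(x, w_gate, w_up, w_down):
    """Bias-free SwiGLU. x: (T, d_model); w_gate, w_up: (d_model, d_ff); w_down: (d_ff, d_model)."""
    gate = jax.nn.silu(x @ w_gate)  # (T, d_ff): how much of each feature to let through
    up = x @ w_up                   # (T, d_ff): the information-carrying branch
    return (gate * up) @ w_down     # (T, d_model): element-wise gate, then project back


T, d_model, d_ff = 3, 8, 16
k_x, k_g, k_u, k_d = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (T, d_model))
w_gate = jax.random.normal(k_g, (d_model, d_ff))
w_up = jax.random.normal(k_u, (d_model, d_ff))
w_down = jax.random.normal(k_d, (d_ff, d_model))

out = swiglu(x, w_gate, w_up, w_down)

# Zero gate weights give a zero pre-activation, and silu(0) = 0,
# so the output vanishes no matter what the up branch computes.
closed = swiglu(x, jnp.zeros_like(w_gate), w_up, w_down)
```

The input and output shapes match, which is what lets the block drop into the same slot as the vanilla FFN.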
Why SiLU?
SiLU (Swish) is z * sigmoid(z). Compared to ReLU:
- Smooth. Derivative is continuous everywhere; ReLU’s derivative jumps at zero.
- Non-monotonic. It dips slightly negative, reaching its minimum around z ≈ -1.28, then recovers toward zero as z becomes more negative. ReLU is rectified, so it is exactly zero (with zero gradient) for negative inputs.
- Self-gated. The built-in * sigmoid(z) factor already acts as a small gate on z itself; SwiGLU adds a separate gate on top.
GeLU (Gaussian Error Linear Unit, used in BERT/GPT-2) is similar in spirit. The choice between GeLU and SiLU is mostly empirical; the measured gaps between them are small, and SiLU is the one most modern LLMs settled on.
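The smoothness and non-monotonicity claims are easy to check numerically. The sketch below is only an illustration: the grid range and the step `eps` are arbitrary choices. It locates the minimum of SiLU on a grid and compares how much the derivative changes across zero for SiLU and for ReLU.

```python
import jax
import jax.numpy as jnp

# Locate the dip of SiLU on a fine grid.
z = jnp.linspace(-5.0, 5.0, 10001)
y = jax.nn.silu(z)
i = jnp.argmin(y)
z_min, y_min = float(z[i]), float(y[i])  # minimum sits near z = -1.28 and is slightly negative

# Compare derivatives just left and right of zero.
silu_grad = jax.grad(jax.nn.silu)
relu_grad = jax.grad(jax.nn.relu)
eps = 1e-3
silu_jump = float(silu_grad(eps) - silu_grad(-eps))  # tiny: the derivative is continuous
relu_jump = float(relu_grad(eps) - relu_grad(-eps))  # 1.0: the derivative jumps from 0 to 1
```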
Worked sketch
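import jax
from flax import nnx

class SwiGLU(nnx.Module):
    def __init__(self, d_model, d_ff, rngs):
        # Two parallel projections up to d_ff, one projection back down.
        self.gate = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.up = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.down = nnx.Linear(d_ff, d_model, rngs=rngs)

    def __call__(self, x):
        # silu goes on the gate branch; the up branch stays linear.
        return self.down(jax.nn.silu(self.gate(x)) * self.up(x))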
Three attributes, one line. Compare with the vanilla FFN:
# Vanilla FFN
return ff2(jax.nn.relu(ff1(x)))
SwiGLU has one extra Linear (the up branch) and element-wise
multiplication in the middle. That’s it — the “innovation” is
really just the gate.
A note on d_ff
With three matrices instead of two, the FLOP cost grows. To keep
total cost roughly equal, real LLaMA-style models use
d_ff ≈ (8/3) * d_model (≈ 2.67x), instead of the usual 4*d_model.
For this problem we just take d_ff as a hyperparameter; the test
cases use d_ff = 16 with d_model = 8 for simplicity.
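The arithmetic behind the 8/3 rule can be checked in a few lines. This sketch counts weight-matrix entries only and ignores biases (an assumption made for simplicity; nnx.Linear adds a bias by default). The helper names and the example width d_model = 4096 are chosen for illustration.

```python
def ffn_params(d_model, d_ff):
    # Vanilla FFN: two matrices, (d_model x d_ff) and (d_ff x d_model).
    return 2 * d_model * d_ff


def swiglu_params(d_model, d_ff):
    # SwiGLU: gate and up are (d_model x d_ff), down is (d_ff x d_model).
    return 3 * d_model * d_ff


d_model = 4096

# Same d_ff: SwiGLU has 3/2 the weights of the vanilla FFN.
same_width_ratio = swiglu_params(d_model, 4 * d_model) / ffn_params(d_model, 4 * d_model)

# Shrinking d_ff to (8/3) * d_model brings the totals back in line:
# 3 * d * (8/3) * d = 8 * d^2 = 2 * d * 4 * d.
d_ff_swiglu = int(8 * d_model / 3)
matched_ratio = swiglu_params(d_model, d_ff_swiglu) / ffn_params(d_model, 4 * d_model)
```

The matched ratio is not exactly 1 only because d_ff has to be an integer.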
Common pitfalls
- One Linear instead of two for gate/up. If you reuse a single nnx.Linear, both branches see the SAME projection: gate and value collapse, and the gating mechanism vanishes.
- silu on the wrong branch. Convention is silu(gate), not silu(up). Mathematically the network can compensate, but you’ll mismatch every published checkpoint.
- Forgetting the multiplication. Without *, you’ve got two independent FFN branches summed (or one alone), not a gated unit.
- Using relu or gelu and calling it SwiGLU. SwiGLU specifically uses SiLU (Swish). With ReLU it’s “ReGLU”; with GeLU it’s “GeGLU”. Different names, different empirical results.
- Wrong order of operations. silu(gate(x)) * up(x) is the element-wise product of the silu’d gate with the raw up. Don’t silu the up branch.
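Two of these pitfalls show up clearly on a tiny hand-built example. The sketch below works on the hidden activation only (before down) and uses hypothetical names (`gated_hidden`, `w_gate`, `w_up`) with hand-picked 2x2 weights. With a shared projection h, the hidden value becomes silu(h) * h = h² · sigmoid(h), which can never be negative, so the gate can no longer act independently of the value. Swapping the activation gives a different unit (ReGLU here), not SwiGLU.

```python
import jax
import jax.numpy as jnp


def gated_hidden(x, w_gate, w_up, act=jax.nn.silu):
    # Hidden activation of a GLU-style block: act(gate(x)) * up(x).
    return act(x @ w_gate) * (x @ w_up)


x = jnp.array([[1.0, -1.0]])
w_gate = jnp.eye(2)   # gate(x) = [ 1, -1]
w_up = -jnp.eye(2)    # up(x)   = [-1,  1]

# Correct: two separate projections. Hidden values can take either sign.
hidden_two = gated_hidden(x, w_gate, w_up)

# Pitfall: one projection reused for both branches. Always >= 0.
hidden_shared = gated_hidden(x, w_gate, w_gate)

# Pitfall: relu in place of silu is ReGLU, a different unit.
hidden_reglu = gated_hidden(x, w_gate, w_up, act=jax.nn.relu)
```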
Problem
Write swiglu_ffn_forward(seed, x, d_model, d_ff):
- SwiGLU(nnx.Module) with gate = nnx.Linear(d_model, d_ff), up = nnx.Linear(d_model, d_ff), down = nnx.Linear(d_ff, d_model).
- __call__(x) returns self.down(jax.nn.silu(self.gate(x)) * self.up(x)).
- Cast d_model, d_ff from float to int. Build nnx.Rngs(int(seed)).
- Return the output flattened: out.reshape(-1).
Inputs:
- seed: int (passed as float).
- x: 2-D (T, d_model).
- d_model, d_ff: ints (passed as floats).
Output: 1-D flattened (T * d_model,).