SwiGLU Feed-Forward Network
Why this matters
Many modern LLMs (LLaMA, PaLM, Mistral) use SwiGLU instead of the
classic Dense → ReLU → Dense FFN; Gemma uses the closely related
GeGLU, which gates with gelu instead. SwiGLU is a gated linear
unit with silu (a.k.a. Swish) as the gate’s nonlinearity. In
published comparisons it tends to outperform ReLU FFNs at the same
parameter count.
The math
Classic FFN (e.g. original Transformer, BERT, ViT):
h = Dense(d_ff)(x)
h = relu(h) # or gelu
out = Dense(D)(h)
Two Dense layers (two matmuls) and one element-wise nonlinearity.
Total params, ignoring biases: D·d_ff + d_ff·D = 2·D·d_ff.
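As a rough illustration of the classic FFN above, here is a minimal pure-JAX sketch. The sizes, the random weights, and the names `w_in`, `w_out`, and `classic_ffn` are assumptions chosen for illustration; biases are left out so the parameter count matches 2·D·d_ff exactly.

```python
import jax
import jax.numpy as jnp

D, d_ff = 4, 8  # illustrative sizes, not taken from any real model
k1, k2 = jax.random.split(jax.random.PRNGKey(0))

# Bias-free weights (an assumption) so the count is exactly 2 * D * d_ff.
w_in = jax.random.normal(k1, (D, d_ff))
w_out = jax.random.normal(k2, (d_ff, D))

def classic_ffn(x):
    h = jax.nn.relu(x @ w_in)  # first matmul, then element-wise ReLU
    return h @ w_out           # second matmul, back to width D

x = jnp.ones((3, D))           # 3 tokens of width D
out = classic_ffn(x)           # shape (3, D)
n_params = w_in.size + w_out.size
```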
SwiGLU replaces this with THREE Dense layers and a multiplicative gate:
gate = Dense(d_ff)(x)
up = Dense(d_ff)(x)
h = silu(gate) * up # element-wise multiply
out = Dense(D)(h)
Three Dense layers (gate, up, down). Total params:
3·D·d_ff — 1.5× a classic FFN at the same d_ff. To match
parameter counts, papers often use d_ff_swiglu = (2/3)·d_ff_relu.
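The parameter arithmetic can be checked with a few lines of Python. The helper names and the choice of D = 4096 with a 4× hidden width are assumptions for illustration; real models typically also round the SwiGLU hidden size to a hardware-friendly multiple.

```python
def classic_params(D, d_ff):
    # Two bias-free Dense layers: D -> d_ff and d_ff -> D.
    return 2 * D * d_ff

def swiglu_params(D, d_ff):
    # Three bias-free Dense layers: gate, up (D -> d_ff) and down (d_ff -> D).
    return 3 * D * d_ff

D = 4096                            # illustrative model width
d_ff_relu = 4 * D                   # common 4x hidden width (assumption)
d_ff_swiglu = (2 * d_ff_relu) // 3  # the (2/3) rule, rounded down

same_width_ratio = swiglu_params(D, d_ff_relu) / classic_params(D, d_ff_relu)
matched_ratio = swiglu_params(D, d_ff_swiglu) / classic_params(D, d_ff_relu)
```

With the hidden width shrunk by 2/3, the two parameter counts agree up to integer rounding.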
What silu is
silu(x) = x * sigmoid(x). Also known as Swish (Ramachandran
et al., 2017). Smooth, non-monotonic, asymptotes to 0 below and to
x above. Like a smoothed ReLU. jax.nn.silu(x) computes it.
silu(0) = 0
silu(-3) ≈ -0.142
silu(3) ≈ 2.858
silu(x) ≈ x for large positive x
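A quick way to convince yourself of the definition is to compare the manual formula with the built-in. This is only a sanity-check sketch; the sample points are arbitrary.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-3.0, 0.0, 3.0, 20.0])

manual = x * jax.nn.sigmoid(x)  # the definition: x * sigmoid(x)
builtin = jax.nn.silu(x)        # the library call

# The two should agree up to floating-point error.
agree = bool(jnp.allclose(manual, builtin, atol=1e-6))
```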
Why gating helps
The output is silu(gate(x)) * up(x) — the up projection is
multiplicatively MASKED by silu(gate). The model learns to
selectively suppress or amplify each up channel based on the
input. Classic FFNs only have ONE pathway through the
nonlinearity; SwiGLU has TWO: a “what to amplify” path (gate)
and a “what content” path (up).
Empirically: with d_ff scaled down to match the parameter count,
SwiGLU trains at about the same cost as a classic FFN and reaches
slightly better quality. The LLaMA family uses it at 70B parameters
and beyond.
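To make the gating effect concrete, the sketch below holds the up values fixed and varies only the gate. The numbers are made up for illustration.

```python
import jax
import jax.numpy as jnp

up = jnp.array([2.0, 2.0, 2.0])      # same "content" in every channel
gate = jnp.array([-10.0, 1.0, 10.0])  # strongly off, partly on, strongly on

h = jax.nn.silu(gate) * up
# Channel 0 is pushed to roughly zero, channel 2 passes roughly gate * up,
# and channel 1 lands in between.
```

Because silu(gate) is roughly the gate value itself for large positive inputs, an open gate scales the channel rather than merely passing it through, which is why the text describes the gate as able to amplify as well as suppress.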
Worked example
D=4, d_ff=8, input x = [1, 0, 0, 1]:
- gate = Dense(8)(x) → (8,).
- up = Dense(8)(x) → (8,).
- silu(gate) * up → element-wise, (8,).
- out = Dense(4)(...) → (4,).
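The shape trace above can be reproduced with plain matmuls. This is a sketch under assumptions: the weights are random and bias-free, and the names `w_gate`, `w_up`, `w_down` are illustrative rather than part of the exercise.

```python
import jax
import jax.numpy as jnp

D, d_ff = 4, 8
x = jnp.array([1.0, 0.0, 0.0, 1.0])

kg, ku, kd = jax.random.split(jax.random.PRNGKey(0), 3)
w_gate = jax.random.normal(kg, (D, d_ff))  # hypothetical weights
w_up = jax.random.normal(ku, (D, d_ff))
w_down = jax.random.normal(kd, (d_ff, D))

gate = x @ w_gate           # (8,)
up = x @ w_up               # (8,)
h = jax.nn.silu(gate) * up  # (8,), element-wise
out = h @ w_down            # (4,)
```

Since x is nonzero only at positions 0 and 3, `gate` here is simply row 0 plus row 3 of `w_gate`, and likewise for `up`.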
Common pitfalls
- Wrong activation: it’s silu (a.k.a. Swish), not sigmoid, not gelu, not relu. jax.nn.silu(x) is the canonical call.
- Activating up instead of gate: the gate is the one with silu, the up is plain. Reversing them changes the output for a given set of weights.
- One Dense for both gate and up: people sometimes fuse them into one Dense(2·d_ff) and split. That’s a valid optimisation and the math is the same, but keep them separate for clarity here.
- Forgetting the down projection: SwiGLU still needs the third Dense(D) to bring the dim back to D. Without it, output is (d_ff,) instead of (D,).
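Two of these pitfalls are easy to check numerically. The sketch below, using hypothetical bias-free random weights, compares separate gate/up projections against a single fused projection, and also shows that activating up instead of gate gives a different result for the same weights.

```python
import jax
import jax.numpy as jnp

D, d_ff = 4, 8
kx, kg, ku = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (3, D))
w_gate = jax.random.normal(kg, (D, d_ff))  # hypothetical weights
w_up = jax.random.normal(ku, (D, d_ff))

# Separate projections, as in the exercise.
h_separate = jax.nn.silu(x @ w_gate) * (x @ w_up)

# Fused projection: one (D, 2*d_ff) matmul, then split in half.
w_fused = jnp.concatenate([w_gate, w_up], axis=1)
gate_f, up_f = jnp.split(x @ w_fused, 2, axis=-1)
h_fused = jax.nn.silu(gate_f) * up_f

# The pitfall: silu applied to up instead of gate.
h_swapped = (x @ w_gate) * jax.nn.silu(x @ w_up)

fused_matches = bool(jnp.allclose(h_separate, h_fused, atol=1e-5))
swapped_matches = bool(jnp.allclose(h_separate, h_swapped, atol=1e-5))
```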
Problem
Implement swiglu_ffn_forward(seed, x, d_ff):
- D = x.shape[-1]. Cast d_ff to int.
- Three Dense layers: gate (D → d_ff), up (D → d_ff), down (d_ff → D).
- h = jax.nn.silu(gate(x)) * up(x).
- out = down(h).
- Return out flattened to 1-D.
Build a small nn.Module (SwiGLUFFN) whose __call__ is decorated
with @nn.compact. Name the three Dense layers "gate", "up", "down"
for readability.
Inputs:
- seed: int.
- x: 2-D (T, D).
- d_ff: int FFN hidden dim.
Output: 1-D array of length T·D, i.e. the (T, D) result flattened.