
SwiGLU Feed-Forward Network

Why this matters

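Modern LLMs such as LLaMA, PaLM, and Mistral use SwiGLU instead of the classic Dense → ReLU → Dense FFN. Gemma uses GeGLU, which has the same gated design with GELU in place of silu. SwiGLU is a gated linear unit with silu (a.k.a. Swish) as the gate’s nonlinearity. In Shazeer’s “GLU Variants Improve Transformer” (2020), it outperformed ReLU FFNs at matched parameter count and compute.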

The math

Classic FFN (e.g. original Transformer, BERT, ViT):

h = Dense(d_ff)(x)
h = relu(h)        # or gelu
out = Dense(D)(h)

Two Dense layers and one nonlinearity give two matmuls. The activation is element-wise and adds no matmul. Total params, ignoring biases: D·d_ff + d_ff·D = 2·D·d_ff.

SwiGLU replaces this with THREE Dense layers and a multiplicative gate:

gate = Dense(d_ff)(x)
up   = Dense(d_ff)(x)
h    = silu(gate) * up           # element-wise multiply
out  = Dense(D)(h)

Three Dense layers (gate, up, down). Total params, ignoring biases: 3·D·d_ff, which is 1.5× a classic FFN at the same d_ff. To match parameter counts, papers often use d_ff_swiglu = (2/3)·d_ff_relu, since 3·D·(2/3)·d_ff_relu = 2·D·d_ff_relu.

What silu is

silu(x) = x * sigmoid(x). Also known as Swish (Ramachandran et al., 2017). Smooth, non-monotonic, asymptotes to 0 below and to x above. Like a smoothed ReLU. jax.nn.silu(x) computes it.

silu(0)  = 0
silu(-3) ≈ -0.142
silu(3)  ≈ 2.857
silu(x)  ≈ x for large x
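The listed values can be reproduced directly. This sketch also checks `jax.nn.silu` against the definition x * sigmoid(x). The input array is an arbitrary choice, and 20.0 stands in for "large x".

```python
import jax
import jax.numpy as jnp

x = jnp.array([-3.0, 0.0, 3.0, 20.0])
silu = jax.nn.silu(x)
manual = x * jax.nn.sigmoid(x)   # the definition silu(x) = x * sigmoid(x)
# silu ≈ [-0.142, 0.0, 2.857, 20.0]
```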

Why gating helps

The output is silu(gate(x)) * up(x), so the up projection is multiplicatively gated by silu(gate(x)). The model learns to selectively suppress or amplify each up channel based on the input. Because silu is not bounded above, this is a scaling rather than a strict 0/1 mask. Classic FFNs have only ONE pathway through the nonlinearity. SwiGLU has TWO: a “how much to pass” path (gate) and a “what content” path (up).
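A toy illustration of the gate acting on the up channels. The pre-activation values below are hand-picked for illustration and do not come from a trained model. A strongly negative gate suppresses its channel, a zero gate zeroes it, and a large positive gate passes the content through scaled by roughly the gate value.

```python
import jax
import jax.numpy as jnp

# Hypothetical pre-activations for three channels.
gate = jnp.array([-10.0, 0.0, 10.0])   # output of the gate path
up = jnp.array([5.0, 5.0, 5.0])        # identical content in every channel

h = jax.nn.silu(gate) * up
# h ≈ [-0.0023, 0.0, 49.998]: channel 0 suppressed, channel 2 amplified ~10x
```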

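Empirically, when d_ff is scaled by 2/3 to match parameter count, the matmul FLOPs also match those of the classic FFN. At that matched budget, Shazeer (2020) reported lower perplexity for SwiGLU than for ReLU and GELU FFNs. SwiGLU has since been used at large scale, for example in PaLM (540B) and LLaMA 2 (70B).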

Worked example

D=4, d_ff=8, input x = [1, 0, 0, 1]:

  1. gate = Dense(8)(x) → (8,).
  2. up = Dense(8)(x) → (8,).
  3. silu(gate) * up → element-wise (8,).
  4. out = Dense(4)(...) → (4,).

Common pitfalls

  • Wrong activation: it’s silu (a.k.a. Swish), not sigmoid, not gelu, not relu. jax.nn.silu(x) is the canonical call.
  • Activating up instead of gate: the gate is the one with silu, the up is plain. Reversing them changes the function.
  • One Dense for both gate and up: people sometimes fuse them into one Dense(2·d_ff) and split the output in half. That is a valid optimisation and computes the same function, but keep them separate here for clarity.
  • Forgetting the down projection: SwiGLU still needs the third Dense(D) to bring the dim back to D. Without it, output is (d_ff,) instead of (D,).

Problem

Implement swiglu_ffn_forward(seed, x, d_ff):

  1. D = x.shape[-1]. Cast d_ff to int.
  2. Three Dense layers: gate (D → d_ff), up (D → d_ff), down (d_ff → D).
  3. h = jax.nn.silu(gate(x)) * up(x).
  4. out = down(h).
  5. Return flattened.

Build a small nn.Module (SwiGLUFFN) inside @nn.compact. Name the three Dense layers "gate", "up", "down" for readability.

Inputs:

  • seed: int.
  • x: 2-D (T, D).
  • d_ff: int FFN hidden dim.

Output: 1-D, flattened (T, D).
