medium primitives

NNX SwiGLU FFN

Why this matters

In the original Transformer (and the encoder/decoder blocks in pos 41/42), the position-wise FFN is Dense -> ReLU -> Dense. Modern LLMs — LLaMA, PaLM, Mistral, DeepSeek — replace this with SwiGLU: a GATED feed-forward block where one branch modulates the other via element-wise multiplication.

At the same d_ff, SwiGLU costs ~50% more parameters and FLOPs than the vanilla FFN (three weight matrices instead of two), but in published comparisons it tends to improve perplexity. Once you know how it works, swapping it into a transformer block is one line of code, which is why it’s everywhere now.

The formula

For input x of shape (T, d_model):

SwiGLU(x) = down( silu(gate(x)) * up(x) )

where:

  • gate, up: each nnx.Linear(d_model, d_ff). Two separate projections to the wider hidden width.
  • silu: a.k.a. the swish nonlinearity, silu(z) = z * sigmoid(z). Smooth, non-monotonic, easy to differentiate.
  • *: element-wise multiplication of two (T, d_ff) tensors. This is the gate that makes the block a gated linear unit (GLU).
  • down: nnx.Linear(d_ff, d_model). Project back to d_model.

Three linears total (vs two in the vanilla FFN). The “up” branch carries information; the “silu(gate)” branch decides how much of each feature to let through. Multiplicative gating gives the network a way to express “this dimension is irrelevant for this token” by setting the gate to ~0, no matter what up(x) says.
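The formula can also be written as a plain function over explicit weight matrices, which makes the shapes easy to follow. This is a minimal sketch, not the nnx version used later: it assumes bias-free projections, and the names swiglu, w_gate, w_up and w_down are chosen here for illustration.

```python
import jax
import jax.numpy as jnp


def swiglu(x, w_gate, w_up, w_down):
    """Bias-free SwiGLU sketch: x (T, d_model) -> (T, d_model)."""
    gate = jax.nn.silu(x @ w_gate)   # (T, d_ff), decides how much passes
    value = x @ w_up                 # (T, d_ff), carries the information
    return (gate * value) @ w_down   # back to (T, d_model)


T, d_model, d_ff = 4, 8, 16
k_x, k_g, k_u, k_d = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (T, d_model))
w_gate = jax.random.normal(k_g, (d_model, d_ff))
w_up = jax.random.normal(k_u, (d_model, d_ff))
w_down = jax.random.normal(k_d, (d_ff, d_model))

out = swiglu(x, w_gate, w_up, w_down)

# A gate pre-activation of 0 gives silu(0) = 0, so every feature is blocked
# regardless of what the up branch computes.
closed = swiglu(x, jnp.zeros_like(w_gate), w_up, w_down)
```

With the gate projection zeroed, the output is all zeros even though the up branch is unchanged, which is the "no matter what up(x) says" behavior described above.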

Why SiLU?

SiLU (Swish) is z * sigmoid(z). Compared to ReLU:

  • Smooth. Derivative is continuous everywhere; ReLU’s derivative jumps at zero.
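  • Non-monotonic. It is negative for every z < 0, reaches a minimum of about -0.28 near z ≈ -1.28, then rises back toward 0 as z goes to -∞. ReLU outputs exactly 0, with zero gradient, for every negative input.
  • Self-gated. The built-in sigmoid(z) factor already gates z by itself; SwiGLU then uses that self-gated output to gate a second, separately projected branch.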
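These properties are easy to check numerically. The sketch below uses only the standard library; the helper names (sigmoid, silu, silu_grad, relu) are defined here for illustration and are not part of any API.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def silu(z):
    return z * sigmoid(z)


def silu_grad(z):
    # d/dz [z * sigmoid(z)] = sigmoid(z) * (1 + z * (1 - sigmoid(z)))
    s = sigmoid(z)
    return s * (1.0 + z * (1.0 - s))


def relu(z):
    return max(z, 0.0)


# Locate the dip with a grid search over [-5, 0] in steps of 0.001.
grid = [i / 1000.0 for i in range(-5000, 1)]
z_min = min(grid, key=silu)

print("argmin of silu :", z_min)
print("min of silu    :", silu(z_min))
print("silu'(0)       :", silu_grad(0.0))   # smooth through zero
print("relu(-1), silu(-1):", relu(-1.0), silu(-1.0))
```

The slope of SiLU at zero is 0.5, whereas the slope of ReLU jumps from 0 to 1 there.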

GeLU (Gaussian Error Linear Unit, used in BERT/GPT-2) is similar in spirit. The choice between GeLU and SiLU is mostly empirical: reported differences between the gated variants are small, and SiLU is the one that LLaMA-style models settled on.

Worked sketch

import jax
from flax import nnx

class SwiGLU(nnx.Module):
    def __init__(self, d_model, d_ff, rngs):
        # Two parallel projections up to d_ff, one projection back down.
        self.gate = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.up   = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.down = nnx.Linear(d_ff, d_model, rngs=rngs)

    def __call__(self, x):
        # x: (T, d_model) -> (T, d_model)
        return self.down(jax.nn.silu(self.gate(x)) * self.up(x))

Three attributes, one line. Compare with the vanilla FFN:

# Vanilla FFN
return ff2(jax.nn.relu(ff1(x)))

SwiGLU has one extra Linear (the up branch), an element-wise multiplication in the middle, and SiLU in place of ReLU. That’s it: the “innovation” is really just the gate.

A note on d_ff

With three matrices instead of two, parameter and FLOP cost grow by 1.5x at a fixed d_ff. To keep total cost roughly equal, real LLaMA-style models use d_ff ≈ (8/3) * d_model (≈ 2.67x) instead of the usual 4 * d_model, since 3 * (8/3) = 2 * 4. For this problem we just take d_ff as a hyperparameter; the test cases use d_ff = 16 with d_model = 8 for simplicity.
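The arithmetic can be checked in a few lines. The sketch below counts weight-matrix entries only (biases ignored). The helper names are hypothetical, and rounding d_ff up to a multiple of 256 is an assumption modeled on LLaMA-style reference code, not something this problem requires.

```python
def ffn_weight_count(d_model, d_ff):
    # Two matrices: (d_model, d_ff) and (d_ff, d_model).
    return 2 * d_model * d_ff


def swiglu_weight_count(d_model, d_ff):
    # Three matrices: gate, up (d_model, d_ff) and down (d_ff, d_model).
    return 3 * d_model * d_ff


def swiglu_d_ff(d_model, multiple_of=256):
    # (8/3) * d_model, rounded up to a hardware-friendly multiple (assumed).
    raw = int(8 * d_model / 3)
    return multiple_of * ((raw + multiple_of - 1) // multiple_of)


d_model = 4096

# Same hidden width: SwiGLU is 1.5x the vanilla FFN.
same_width_ratio = (swiglu_weight_count(d_model, 4 * d_model)
                    / ffn_weight_count(d_model, 4 * d_model))

# Shrunk hidden width: cost is roughly matched again.
d_ff = swiglu_d_ff(d_model)
matched_ratio = (swiglu_weight_count(d_model, d_ff)
                 / ffn_weight_count(d_model, 4 * d_model))

print(same_width_ratio, d_ff, matched_ratio)
```

For d_model = 4096 this gives d_ff = 11008, and the matched ratio lands within 1% of 1.0.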

Common pitfalls

  • One Linear instead of two for gate/up. If you reuse a single nnx.Linear, both branches see the SAME projection h, and the block collapses to silu(h) * h: there is no independent gate left.
  • silu on the wrong branch. Convention is silu(gate), not silu(up). Swapping them is the same function class with the branches relabeled, but your weights will not line up with published checkpoints.
  • Forgetting the multiplication. Without *, you’ve got two independent FFN branches summed (or one alone) — not gated.
  • Using relu or gelu and calling it SwiGLU. SwiGLU specifically uses SiLU (Swish). With ReLU it’s “ReGLU”; with GeLU it’s “GeGLU”. Different names, different empirical results.
  • Wrong order of operations. Apply silu to the gate branch first, then multiply: silu(gate(x)) * up(x), not silu(gate(x) * up(x)).
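Several of these pitfalls can be seen numerically. The sketch below is bias-free and uses illustrative names (glu_ffn, w_gate, w_up, w_down); it swaps the activation to get the ReGLU/GeGLU variants, puts silu on the wrong branch, and shares one projection between gate and up.

```python
import jax
import jax.numpy as jnp


def glu_ffn(x, w_gate, w_up, w_down, act):
    # Generic gated FFN: the activation chosen for the gate names the variant.
    return (act(x @ w_gate) * (x @ w_up)) @ w_down


keys = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(keys[0], (4, 8))
w_gate = jax.random.normal(keys[1], (8, 16))
w_up = jax.random.normal(keys[2], (8, 16))
w_down = jax.random.normal(keys[3], (16, 8))

swiglu = glu_ffn(x, w_gate, w_up, w_down, jax.nn.silu)
reglu = glu_ffn(x, w_gate, w_up, w_down, jax.nn.relu)   # not SwiGLU
geglu = glu_ffn(x, w_gate, w_up, w_down, jax.nn.gelu)   # not SwiGLU

# silu on the wrong branch: same weights, different output.
swapped = ((x @ w_gate) * jax.nn.silu(x @ w_up)) @ w_down

# One shared projection: the hidden activation is h^2 * sigmoid(h),
# which can never be negative, so the gate carries no independent signal.
h = x @ w_gate
collapsed_hidden = jax.nn.silu(h) * h
```

With the same weights, all four outputs differ, so loading a SwiGLU checkpoint into any of the other layouts would silently compute the wrong function.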

Problem

Write swiglu_ffn_forward(seed, x, d_model, d_ff):

  1. SwiGLU(nnx.Module) with gate = nnx.Linear(d_model, d_ff), up = nnx.Linear(d_model, d_ff), down = nnx.Linear(d_ff, d_model).
  2. __call__(x) returns self.down(jax.nn.silu(self.gate(x)) * self.up(x)).
  3. Cast d_model, d_ff from float to int. Build nnx.Rngs(int(seed)).
  4. Return the output flattened: out.reshape(-1).

Inputs:

  • seed: int (passed as float).
  • x: 2-D (T, d_model).
  • d_model, d_ff: ints (passed as floats).

Output: 1-D flattened (T * d_model,).

Hints

flax nnx ffn swiglu gated llama architecture
