
NNX Port: Linen Transformer Block → NNX

Why this matters

Porting a single Dense was easy; porting an MLP was bookkeeping. Porting a full Transformer encoder block is the realistic case: a nested module tree where you must walk the Linen params dict and drop each leaf into the right nnx attribute.

Once you can do this for a Transformer block, the same approach carries over to most research codebases. The only project-specific piece is the mapping table: which Linen path goes to which nnx attribute. The mechanics are the same: assign .value on each leaf, recursing through the submodules.

The block layout

Standard pre-norm encoder (LayerNorm runs before each sublayer, and each sublayer's output is added back to x):

h = LayerNorm(x)               # ln1
h = MultiHeadAttention(h)      # attn (q/k/v/out projections)
x = x + h
h = LayerNorm(x)               # ln2
h = Linear(d_model -> d_ff)(h) # ff1
h = relu(h)
h = Linear(d_ff -> d_model)(h) # ff2
x = x + h

To make porting clean, we use the same hand-built MHA in both Linen and nnx (four Dense projections — q_proj, k_proj, v_proj, out_proj). This way the param tree mirrors exactly:

linen_params["params"]:
  ln1:    {scale, bias}
  attn:
    q_proj:   {kernel, bias}
    k_proj:   {kernel, bias}
    v_proj:   {kernel, bias}
    out_proj: {kernel, bias}
  ln2:    {scale, bias}
  ff1:    {kernel, bias}
  ff2:    {kernel, bias}

Each Linen submodule is given an explicit name=..., and that name becomes its key under params. (Without name=, Linen would use auto-generated Dense_0, Dense_1, … and the mapping is harder to read.)

The nnx attributes match: nnx_block.ln1, nnx_block.attn.q_proj, etc.
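Before writing the mapping, it can help to list every leaf path in the params tree and check it against the layout above. The following is a minimal sketch in plain Python; leaf_paths and example_params are hypothetical names, and the None leaves stand in for the real arrays.

```python
from collections.abc import Mapping


def leaf_paths(tree, prefix=()):
    """Yield 'a/b/c' style paths for every leaf of a nested params mapping."""
    for key, sub in tree.items():
        if isinstance(sub, Mapping):
            # Submodule: descend and extend the path.
            yield from leaf_paths(sub, prefix + (key,))
        else:
            # Leaf (an array in a real params tree).
            yield "/".join(prefix + (key,))


# Stand-in for linen_params["params"]; None replaces the actual arrays.
dense = {"kernel": None, "bias": None}
norm = {"scale": None, "bias": None}
example_params = {
    "ln1": dict(norm),
    "attn": {name: dict(dense) for name in ("q_proj", "k_proj", "v_proj", "out_proj")},
    "ln2": dict(norm),
    "ff1": dict(dense),
    "ff2": dict(dense),
}

paths = list(leaf_paths(example_params))
print(len(paths))  # 16 leaves: 4 LayerNorm + 8 attention + 4 FFN
for path in paths:
    print(path)
```

The same function should work on a real Linen params tree, since it only assumes a nested mapping with arrays at the leaves.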

The port loop

p = linen_params["params"]

# LayerNorms — note 'scale' (not 'kernel'!).
nnx_block.ln1.scale.value = p["ln1"]["scale"]
nnx_block.ln1.bias.value  = p["ln1"]["bias"]
nnx_block.ln2.scale.value = p["ln2"]["scale"]
nnx_block.ln2.bias.value  = p["ln2"]["bias"]

# Attention — four projections.
for proj in ["q_proj", "k_proj", "v_proj", "out_proj"]:
    getattr(nnx_block.attn, proj).kernel.value = p["attn"][proj]["kernel"]
    getattr(nnx_block.attn, proj).bias.value   = p["attn"][proj]["bias"]

# FFN.
nnx_block.ff1.kernel.value = p["ff1"]["kernel"]
nnx_block.ff1.bias.value   = p["ff1"]["bias"]
nnx_block.ff2.kernel.value = p["ff2"]["kernel"]
nnx_block.ff2.bias.value   = p["ff2"]["bias"]
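Because the nnx attribute names mirror the Linen keys here, the explicit assignments above could also be driven by a recursive walk. This is only a sketch under that assumption: port_params is a hypothetical helper, and the demo uses SimpleNamespace stand-ins rather than real nnx modules, so it shows the traversal and nothing Flax-specific.

```python
from collections.abc import Mapping
from types import SimpleNamespace


def port_params(module, params):
    """Copy a nested Linen-style params mapping into same-named attributes."""
    for name, sub in params.items():
        target = getattr(module, name)  # AttributeError signals a mapping mismatch
        if isinstance(sub, Mapping):
            port_params(target, sub)    # submodule: recurse
        else:
            target.value = sub          # leaf: overwrite the parameter value


def leaf():
    # Stand-in for a parameter object that exposes a .value attribute.
    return SimpleNamespace(value=None)


fake_block = SimpleNamespace(
    ln1=SimpleNamespace(scale=leaf(), bias=leaf()),
    attn=SimpleNamespace(q_proj=SimpleNamespace(kernel=leaf(), bias=leaf())),
)
fake_params = {
    "ln1": {"scale": 1.0, "bias": 0.0},
    "attn": {"q_proj": {"kernel": 2.0, "bias": 0.5}},
}
port_params(fake_block, fake_params)
print(fake_block.attn.q_proj.kernel.value)
```

The explicit version is easier to audit for a single block. The recursive version pays off when the tree is deep and the names really do match one-to-one.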

All math is identical, so nnx_block(x) should match linen_block.apply(linen_params, x) up to floating-point tolerance; compare with jnp.allclose rather than ==.
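When checking the port, a tolerance-based comparison is safer than exact equality on floats. A minimal sketch with NumPy follows; outputs_match and the atol default are hypothetical choices, not part of the problem statement.

```python
import numpy as np


def outputs_match(nnx_out, linen_out, atol=1e-5):
    """Return True when the two outputs agree within a small tolerance."""
    return bool(np.allclose(np.asarray(nnx_out), np.asarray(linen_out), atol=atol))


a = np.array([1.0, 2.0], dtype=np.float32)
print(outputs_match(a, a + np.float32(1e-6)))  # tiny float noise is accepted
print(outputs_match(a, a + np.float32(1.0)))   # a real mismatch is rejected
```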

Why match MHA layouts?

The official nn.MultiHeadDotProductAttention uses fused DenseGeneral projections with multi-axis kernels of shape (d_model, num_heads, head_dim). Four plain Linear layers store flat 2-D kernels instead. To port between the two layouts, you’d need a reshape: easy to forget, annoying to debug.
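For reference, the reshape in question can be sketched with NumPy. The example assumes a query-style kernel of shape (d_model, num_heads, head_dim) with a bias of shape (num_heads, head_dim); the sizes and variable names are made up for illustration.

```python
import numpy as np

T, d_model, num_heads, head_dim = 5, 8, 2, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((T, d_model)).astype(np.float32)

# Fused layout: one kernel slice per head.
fused_kernel = rng.standard_normal((d_model, num_heads, head_dim)).astype(np.float32)
fused_bias = rng.standard_normal((num_heads, head_dim)).astype(np.float32)

# Flat layout: merge the (num_heads, head_dim) axes into a single feature axis.
flat_kernel = fused_kernel.reshape(d_model, num_heads * head_dim)
flat_bias = fused_bias.reshape(num_heads * head_dim)

fused_out = np.einsum("td,dhk->thk", x, fused_kernel) + fused_bias  # (T, H, head_dim)
flat_out = x @ flat_kernel + flat_bias                              # (T, H * head_dim)

# Both layouts compute the same projection once the head axes are merged.
assert np.allclose(fused_out.reshape(T, -1), flat_out, atol=1e-5)
```

The output projection goes the other way: a kernel of shape (num_heads, head_dim, d_model) would flatten to (num_heads * head_dim, d_model).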

The cleaner path: implement MHA the same way in both worlds (four Dense/Linear projections), then porting is verbatim. Separate projections are also a natural fit for grouped-query and multi-query variants, where the k/v projections use fewer heads than q.

Common pitfalls

  • scale vs. kernel for LayerNorm. Linen and nnx both call the learnable scale scale (not weight, not gamma). The bias is bias. There’s no kernel.
  • Skipping attn namespace. The four projections live under params["attn"][...], not at the top level.
  • Reordering submodule declarations in Linen. Without explicit name=, Linen auto-names submodules by creation order (Dense_0, Dense_1, …), so reordering them changes which key holds which weights. Use name= for stability.
  • Forgetting that add is residual, not concat. The block output is x + h, not concat(x, h). The nnx forward pass must mirror this.

Problem

Define LinenMHA, LinenEncoderBlock, NNXMHA, NNXEncoderBlock as shown in the reference. Then write port_transformer_block(seed, x, num_heads, d_model, d_ff):

  1. Build & init the Linen block; compute linen_out.
  2. Build the nnx block with nnx.Rngs(int(seed) + 1).
  3. Port every leaf: 2 LayerNorms (scale/bias), 4 attention projections (kernel/bias each), 2 FFN linears (kernel/bias each).
  4. Compute nnx_out; return jnp.array([float(nnx_out.sum()), float(linen_out.sum())]).

Both sums must match.

Inputs:

  • seed: int (passed as float).
  • x: 2-D (T, d_model).
  • num_heads, d_model, d_ff: ints (passed as floats).

Output: length-2 array [nnx_sum, linen_sum]; after a correct port the two entries are equal.

Hints

flax nnx port weights interop linen transformer
