NNX Port: Linen Transformer Block → NNX
Why this matters
Porting a single Dense was easy; porting an MLP was bookkeeping.
Porting a full Transformer encoder block is the realistic case: a
nested module tree where you must walk the Linen params dict and
drop each leaf into the right nnx attribute.
Once you can do this for a Transformer block, the same approach
carries over to most research codebases. The only project-specific
piece is the mapping table: which Linen path goes to which nnx
attribute. The mechanics never change: assign each leaf with
.value =, recursing through the module tree.
The block layout
Standard prenorm encoder:
h = LayerNorm(x) # ln1
h = MultiHeadAttention(h) # attn (q/k/v/out projections)
x = x + h
h = LayerNorm(x) # ln2
h = Linear(d_model -> d_ff)(h) # ff1
h = relu(h)
h = Linear(d_ff -> d_model)(h) # ff2
x = x + h
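The layout above can be spelled out as a small runnable sketch. This is a plain NumPy stand-in, not the reference Linen or nnx code: the function names (layer_norm, dense, mha, encoder_block, init_params) are hypothetical, eps=1e-6 is an assumed LayerNorm epsilon, and masking and dropout are left out. The params dict uses the same key names as the tree described in this section.

```python
import numpy as np


def layer_norm(x, scale, bias, eps=1e-6):
    # Normalize over the feature axis, then apply the learnable scale and bias.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * scale + bias


def dense(x, p):
    return x @ p["kernel"] + p["bias"]


def mha(x, p, num_heads):
    T, d_model = x.shape
    head_dim = d_model // num_heads

    def split(name):
        # Project, then split features into heads: (num_heads, T, head_dim).
        return dense(x, p[name]).reshape(T, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split("q_proj"), split("k_proj"), split("v_proj")
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)  # (num_heads, T, T)
    scores = scores - scores.max(axis=-1, keepdims=True)   # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    merged = (weights @ v).transpose(1, 0, 2).reshape(T, d_model)
    return dense(merged, p["out_proj"])


def encoder_block(x, p, num_heads):
    h = layer_norm(x, p["ln1"]["scale"], p["ln1"]["bias"])
    x = x + mha(h, p["attn"], num_heads)                   # residual add
    h = layer_norm(x, p["ln2"]["scale"], p["ln2"]["bias"])
    h = np.maximum(dense(h, p["ff1"]), 0.0)                # relu
    return x + dense(h, p["ff2"])                          # residual add


def init_params(rng, d_model, d_ff):
    def lin(i, o):
        return {"kernel": rng.normal(size=(i, o)) * 0.1, "bias": np.zeros(o)}

    def ln():
        return {"scale": np.ones(d_model), "bias": np.zeros(d_model)}

    attn = {n: lin(d_model, d_model) for n in ("q_proj", "k_proj", "v_proj", "out_proj")}
    return {"ln1": ln(), "attn": attn, "ln2": ln(),
            "ff1": lin(d_model, d_ff), "ff2": lin(d_ff, d_model)}


rng = np.random.default_rng(0)
params = init_params(rng, d_model=8, d_ff=16)
x = rng.normal(size=(5, 8))                                # (T, d_model)
out = encoder_block(x, params, num_heads=2)                # same shape as x
```

Because both residual branches are added to x, zeroing out_proj and ff2 turns the block into the identity, which is a quick sanity check on the wiring.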
To make porting clean, we use the same hand-built MHA in both
Linen and nnx (four Dense projections — q_proj, k_proj,
v_proj, out_proj). This way the param tree mirrors exactly:
linen_params["params"]:
  ln1: {scale, bias}
  attn:
    q_proj: {kernel, bias}
    k_proj: {kernel, bias}
    v_proj: {kernel, bias}
    out_proj: {kernel, bias}
  ln2: {scale, bias}
  ff1: {kernel, bias}
  ff2: {kernel, bias}
Each Linen submodule is given an explicit name=... so that’s the
key under params. (Without name=, Linen would use auto-generated
Dense_0, Dense_1, … and the mapping is harder to read.)
The nnx attributes match: nnx_block.ln1, nnx_block.attn.q_proj,
etc.
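Before porting, it can help to list every leaf path and shape in the Linen tree and compare it against the mapping table. The helper below is a hypothetical sketch (leaf_shapes is not a Flax function), and the dummy tree is built from NumPy zeros with arbitrary sizes d_model=8 and d_ff=16. With a real model you would pass linen_params["params"] instead.

```python
from collections.abc import Mapping

import numpy as np


def leaf_shapes(tree, prefix=()):
    """Yield ("a/b/c", shape) for every array leaf of a nested params mapping."""
    for key, sub in tree.items():
        path = prefix + (key,)
        if isinstance(sub, Mapping):
            yield from leaf_shapes(sub, path)   # descend into a submodule
        else:
            yield "/".join(path), np.shape(sub)


# Dummy tree with the layout described above (sizes are arbitrary).
d_model, d_ff = 8, 16


def _dense(i, o):
    return {"kernel": np.zeros((i, o)), "bias": np.zeros(o)}


def _ln():
    return {"scale": np.zeros(d_model), "bias": np.zeros(d_model)}


demo_params = {
    "ln1": _ln(),
    "attn": {n: _dense(d_model, d_model)
             for n in ("q_proj", "k_proj", "v_proj", "out_proj")},
    "ln2": _ln(),
    "ff1": _dense(d_model, d_ff),
    "ff2": _dense(d_ff, d_model),
}

shapes = dict(leaf_shapes(demo_params))
for path, shape in shapes.items():
    print(path, shape)
```

The block has 16 leaves in total: 2 LayerNorms with 2 leaves each, 4 attention projections with 2 each, and 2 feed-forward layers with 2 each.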
The port loop
p = linen_params["params"]
# LayerNorms — note 'scale' (not 'kernel'!).
nnx_block.ln1.scale.value = p["ln1"]["scale"]
nnx_block.ln1.bias.value = p["ln1"]["bias"]
nnx_block.ln2.scale.value = p["ln2"]["scale"]
nnx_block.ln2.bias.value = p["ln2"]["bias"]
# Attention — four projections.
for proj in ["q_proj", "k_proj", "v_proj", "out_proj"]:
    getattr(nnx_block.attn, proj).kernel.value = p["attn"][proj]["kernel"]
    getattr(nnx_block.attn, proj).bias.value = p["attn"][proj]["bias"]
# FFN.
nnx_block.ff1.kernel.value = p["ff1"]["kernel"]
nnx_block.ff1.bias.value = p["ff1"]["bias"]
nnx_block.ff2.kernel.value = p["ff2"]["kernel"]
nnx_block.ff2.bias.value = p["ff2"]["bias"]
All math is identical, so nnx_block(x) matches
linen_block.apply(linen_params, x) up to floating-point tolerance;
compare the two with jnp.allclose rather than ==.
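When the Linen keys and the nnx attribute names mirror each other exactly, the explicit assignments can be collapsed into one recursive walk. The following is a minimal sketch, assuming every key in the params tree has a same-named attribute on the target and that leaves expose a writable .value. port_tree is a hypothetical helper, and the SimpleNamespace objects are stand-ins so the example runs without Flax installed.

```python
from collections.abc import Mapping
from types import SimpleNamespace


def port_tree(module, params):
    """Copy every leaf of a nested params mapping onto same-named attributes."""
    for name, sub in params.items():
        target = getattr(module, name)   # AttributeError flags a mapping mismatch
        if isinstance(sub, Mapping):
            port_tree(target, sub)       # submodule: recurse
        else:
            target.value = sub           # leaf: overwrite the parameter value


def fake_param():
    # Stand-in for a parameter object with a .value attribute.
    return SimpleNamespace(value=None)


fake_block = SimpleNamespace(
    ln1=SimpleNamespace(scale=fake_param(), bias=fake_param()),
    attn=SimpleNamespace(
        q_proj=SimpleNamespace(kernel=fake_param(), bias=fake_param()),
    ),
)
fake_params = {
    "ln1": {"scale": 1.0, "bias": 0.0},
    "attn": {"q_proj": {"kernel": 2.0, "bias": 0.5}},
}

port_tree(fake_block, fake_params)
```

A misspelled key fails loudly at getattr instead of silently leaving a freshly initialized weight in place, which is usually the failure you want when porting.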
Why match MHA layouts?
The official nn.MultiHeadDotProductAttention uses fused
DenseGeneral with multi-axis kernels of shape
(d_model, num_heads, head_dim). nnx.Linear has no multi-axis
kernel, so porting those weights into four plain Linear
projections would need a reshape on every kernel and bias: easy
to forget, annoying to debug.
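For reference, the reshape in question looks like this for a q/k/v-style projection. This is a NumPy sketch with random weights and arbitrary sizes (the same calls exist in jax.numpy); the variable names are illustrative, not Flax attributes.

```python
import numpy as np

T, d_model, num_heads = 5, 8, 2
head_dim = d_model // num_heads
rng = np.random.default_rng(0)

x = rng.normal(size=(T, d_model))
fused_kernel = rng.normal(size=(d_model, num_heads, head_dim))
fused_bias = rng.normal(size=(num_heads, head_dim))

# Fused projection: the output keeps separate head axes, (T, num_heads, head_dim).
fused_out = np.einsum("td,dhk->thk", x, fused_kernel) + fused_bias

# Flattened equivalent for a plain 2-D Dense/Linear layer.
flat_kernel = fused_kernel.reshape(d_model, num_heads * head_dim)
flat_bias = fused_bias.reshape(num_heads * head_dim)
flat_out = x @ flat_kernel + flat_bias                     # (T, d_model)

# Splitting the flat output back into heads recovers the fused result.
heads_match = np.allclose(flat_out.reshape(T, num_heads, head_dim), fused_out)
```

The output projection would be the mirror case, assuming its fused kernel has shape (num_heads, head_dim, d_model): flatten the first two axes instead of the last two.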
The cleaner path: implement MHA the same way in both worlds (four
Dense/Linear projections), then porting is verbatim. This is
also how research repos increasingly write attention since the rise
of grouped/multi-query variants.
Common pitfalls
- scale vs. kernel for LayerNorm. Linen and nnx both name the learnable gain "scale" (not "weight", not "gamma"). The bias is "bias". There’s no kernel.
- Skipping the attn namespace. The four projections live under params["attn"][...], not at the top level.
- Reordering submodule declarations in Linen. Without explicit name=, the auto-generated names (Dense_0, Dense_1, …) follow declaration order, so reordering changes the keys. Use name= for stability.
- Forgetting that add is residual, not concat. The block output is x + h, not concat(x, h). The nnx forward pass must mirror this.
Problem
Define LinenMHA, LinenEncoderBlock, NNXMHA, NNXEncoderBlock
as shown in the reference. Then write port_transformer_block(seed, x, num_heads, d_model, d_ff):
- Build & init the Linen block; compute linen_out.
- Build the nnx block with nnx.Rngs(int(seed) + 1).
- Port every leaf: 2 LayerNorms (scale/bias), 4 attention projections (kernel/bias each), 2 FFN linears (kernel/bias each).
- Compute nnx_out; return jnp.array([float(nnx_out.sum()), float(linen_out.sum())]).
Both sums must match.
Inputs:
- seed: int (passed as float).
- x: 2-D (T, d_model).
- num_heads, d_model, d_ff: ints (passed as floats).
Output: length-2 array [s, s] (sums equal).
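While debugging, an element-wise comparison is a stricter check than comparing the two sums, because different arrays can share the same sum. The sketch below uses NumPy with assumed tolerances (rtol=1e-5, atol=1e-5); outputs_match is a hypothetical helper, not part of the problem's required interface.

```python
import numpy as np


def outputs_match(a, b, rtol=1e-5, atol=1e-5):
    # Element-wise comparison within the assumed tolerances.
    return bool(np.allclose(a, b, rtol=rtol, atol=atol))


a = np.array([[1.0, 2.0], [3.0, 4.0]])
nearly_a = a + 1e-7          # differs only by rounding-sized noise
permuted = a[::-1]           # same values in different rows, so the same sum

sums_agree = bool(np.isclose(a.sum(), permuted.sum()))
close_enough = outputs_match(a, nearly_a)
rows_swapped_detected = not outputs_match(a, permuted)
```

Here the sum check cannot tell a from the row-swapped array, while the element-wise check can.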