NNX Bridge: Shared Params Across nnx and Linen
Why this matters
nnx and Linen feel like different worlds — different APIs, different
abstractions, different mental models. But the params underneath are
just JAX arrays. There’s no nnx.Tensor or linen.Tensor. Both
frameworks store kernels and biases as plain jax.Arrays; the rest is
just bookkeeping.
Once you internalize that, “sharing parameters across the two frameworks” stops being magic. You point at the array. That’s it.
The setup
Build an nnx.Linear and a flax.linen.Dense with the same
(in_features, out_features) shape. The Linen Dense lazily declares
its kernel via self.param; the nnx Linear holds it as
nnx_linear.kernel: nnx.Param. Both kernels are arrays of shape
(in, out) — the exact same layout.
To share: take the nnx kernel array and use it as the value in the
Linen params dict.
from flax import linen as nn
from flax import nnx

# in_f, out_f, seed, and x (last axis of size in_f) are assumed to be defined.
nnx_linear = nnx.Linear(in_f, out_f, rngs=nnx.Rngs(seed), use_bias=False)
linen_dense = nn.Dense(features=out_f, use_bias=False)
# Treat the nnx kernel as canonical.
shared_kernel = nnx_linear.kernel.value  # (in_f, out_f)
linen_params = {"params": {"kernel": shared_kernel}}
# Both forward passes now use the SAME array.
nnx_y = nnx_linear(x)
linen_y = linen_dense.apply(linen_params, x)
Adding nnx_y + linen_y produces 2 * (x @ shared_kernel) — proof
that both halves agree on which kernel they’re using.
Why this is non-obvious
Newcomers often assume nnx params and Linen params are different types, or that you need a converter. They aren’t, and you don’t.
- nnx.Param(value=arr).value: the underlying JAX array.
- params["params"]["kernel"]: the same kind of underlying JAX array.
The wrappers (nnx.Param, the Linen params dict) are containers
that record metadata (shapes, dtypes, sharding hints, names). Pull the
array out of one container and stuff it into the other. The math
doesn’t care which container you came from.
Why share params at all?
Common cases:
- Weight tying in language models: input embedding and output projection share the same (vocab, d_model) matrix. If your embedding is Linen and your output head is nnx (or vice versa), you’d otherwise have to choose one framework or maintain two copies.
- Multi-head ensembles where one head is an nnx model and another is a Linen model wrapping a different architecture, but both use a shared backbone embedding.
- Checkpoint surgery: load a Linen checkpoint, point an nnx module at the array. No conversion script needed.
Common pitfalls
- Different bias conventions. nnx.Linear and nn.Dense both have use_bias=True by default; if you share a kernel but use independent biases, the two outputs no longer match. We disable bias for clarity.
- Kernel shape mismatch. Both store (in, out) for plain Linear/Dense, so the convention is the same. (HuggingFace PyTorch ships (out, in); that’s a different problem, see pos 79.)
- Forgetting .value. nnx_linear.kernel is the nnx.Param wrapper; .value (or indexing the wrapper) gives the array.
- Mutating the shared array. JAX arrays are immutable, so in-place mutation isn’t a real risk. But if you assign a new array back to one side, the other side won’t see the change. Sharing is by reference, not by mirror.
Problem
Write shared_params_two_modules(seed, x, features):
- Build nnx_linear = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed)), use_bias=False).
- Build linen_dense = nn.Dense(features=int(features), use_bias=False).
- Extract shared_kernel = nnx_linear.kernel.value.
- Build the Linen params dict: {"params": {"kernel": shared_kernel}}.
- Compute nnx_out = nnx_linear(x) and linen_out = linen_dense.apply(<params>, x).
- Return nnx_out + linen_out.
Since both modules use the same kernel and same input, linen_out == nnx_out, so the sum is 2 * nnx_out.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: 1-D array of length features (sum of nnx and Linen
outputs, equal to twice either one).