
NNX Bridge: Shared Params Across nnx and Linen

Why this matters

nnx and Linen feel like different worlds — different APIs, different abstractions, different mental models. But the params underneath are just JAX arrays. There’s no nnx.Tensor or linen.Tensor. Both frameworks store kernels and biases as plain jax.Arrays; the rest is just bookkeeping.

Once you internalize that, “sharing parameters across the two frameworks” stops being magic. You point at the array. That’s it.

The setup

Build an nnx.Linear and a flax.linen.Dense with the same (in_features, out_features) shape. The Linen Dense lazily declares its kernel via self.param; the nnx Linear holds it as nnx_linear.kernel: nnx.Param. Both kernels are arrays of shape (in, out), so the layout is identical.

To share: take the nnx kernel array and use it as the value in the Linen params dict.

import jax.numpy as jnp
import flax.linen as nn
from flax import nnx

# Illustrative sizes, seed, and input; any (in_f,) vector works.
in_f, out_f, seed = 3, 4, 0
x = jnp.ones((in_f,))

nnx_linear = nnx.Linear(in_f, out_f, rngs=nnx.Rngs(seed), use_bias=False)
linen_dense = nn.Dense(features=out_f, use_bias=False)

# Treat the nnx kernel as canonical.
shared_kernel = nnx_linear.kernel.value          # (in_f, out_f)
linen_params  = {"params": {"kernel": shared_kernel}}

# Both forward passes now use the SAME array.
nnx_y   = nnx_linear(x)
linen_y = linen_dense.apply(linen_params, x)

Adding nnx_y + linen_y produces 2 * (x @ shared_kernel) — proof that both halves agree on which kernel they’re using.

Why this is non-obvious

Newcomers often assume nnx params and Linen params hold different kinds of values, or that you need a converter to move between them. They don’t, and you don’t.

  • nnx.Param(value=arr).value — the underlying JAX array.
  • params["params"]["kernel"] — the same kind of underlying JAX array.

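The wrappers are only containers: nnx.Param wraps a single array and can carry extra metadata (for example sharding annotations), while the Linen params dict is a nested mapping from names to arrays. Shapes and dtypes live on the arrays themselves. Pull the array out of one container and put it into the other. The math doesn’t care which container it came from.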

Why share params at all?

Common cases:

  • Weight tying in language models: the input embedding and the output projection share the same (vocab, d_model) matrix, with the output projection applying its transpose. If your embedding is Linen and your output head is nnx (or vice versa), you’d otherwise have to choose one framework or maintain two copies.
  • Multi-head ensembles where one head is an nnx model and another is a Linen model wrapping a different architecture, but both use a shared backbone embedding.
  • Checkpoint surgery: load a Linen checkpoint, point an nnx module at the array. No conversion script needed.

Common pitfalls

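  • Independent biases. nnx.Linear and nn.Dense both default to use_bias=True. If you share the kernel but each module keeps its own bias, the two outputs differ by the gap between the biases. Both initialize biases to zeros, so the gap only appears once one side’s bias is trained or loaded separately. We disable bias for clarity.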
  • Kernel shape mismatch. Both store (in, out) for plain Linear/Dense, so the convention is the same. (PyTorch nn.Linear weights, as shipped in HuggingFace PyTorch checkpoints, are stored as (out, in); that’s a different problem, see pos 79.)
  • Forgetting .value. nnx_linear.kernel is the nnx.Param wrapper; nnx_linear.kernel.value (or indexing the wrapper, nnx_linear.kernel[...]) gives the array.
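  • Mutating the shared array. JAX arrays are immutable, so in-place mutation isn’t a real risk. But if you assign a new array back to one side (as an optimizer update does), the other side won’t see the change. Sharing hands both sides a reference to the same array at that moment; it does not keep them mirrored afterward.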

Problem

Write shared_params_two_modules(seed, x, features):

  1. Build nnx_linear = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed)), use_bias=False).
  2. Build linen_dense = nn.Dense(features=int(features), use_bias=False).
  3. Extract shared_kernel = nnx_linear.kernel.value.
  4. Build the Linen params dict: {"params": {"kernel": shared_kernel}}.
  5. Compute nnx_out = nnx_linear(x) and linen_out = linen_dense.apply(<params>, x).
  6. Return nnx_out + linen_out.

Since both modules use the same kernel and same input, linen_out == nnx_out, so the sum is 2 * nnx_out.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).

Output: 1-D array of length features (sum of nnx and Linen outputs, equal to twice either one).

