NNX Bridge: State Translation NNX → Linen

Why this matters

nnx and Linen aren’t separate worlds — they’re two containers around the same JAX arrays. nnx.split(model) produces a state pytree of nnx.Variable wrappers; Linen’s params is a nested dict of plain arrays. Translating between them is rewrapping, not conversion.

Once you can translate state in either direction, you can:

  • Save an nnx model in Linen-compatible format for downstream tools.
  • Load a Linen checkpoint into an nnx model.
  • Run a parallel sanity check (apply the same numerical params via both frameworks; outputs must match).

The translation

Build an nnx.Linear. Use nnx.split to get its state:

from flax import nnx

in_f, out_f = 3, 4  # example sizes
nnx_linear = nnx.Linear(in_f, out_f, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(nnx_linear)
# state["kernel"]: Param wrapping an (in_f, out_f) array
# state["bias"]:   Param wrapping an (out_f,) array

The state is an nnx.State: a nested, dict-like pytree whose entries are nnx.Variable wrappers (nnx.Param for trainable weights). Some Flax releases use an nnx.VariableState wrapper here instead, but both expose .value, which returns the underlying jax.Array.

Linen wants:

linen_params = {
    "params": {
        "kernel": <jax.Array>,
        "bias":   <jax.Array>,
    }
}

Translation is direct:

linen_params = {
    "params": {
        "kernel": state["kernel"].value,
        "bias":   state["bias"].value,
    }
}

Now nn.Dense(out_f).apply(linen_params, x) produces the same output as nnx_linear(x): both compute x @ kernel + bias, with the kernel stored as (in, out), on the very same arrays.

Why “isomorphic” not “identical”

state["kernel"] is an nnx.Param wrapper. Around the array, it records the variable type plus any attached metadata (e.g. sharding annotations). Linen’s params["kernel"] is just the array.

The wrappers exist for different reasons:

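  • nnx.Variable: marks the array’s role (trainable param, BN stat, RNG key, KV cache). nnx APIs filter on this type. For example, nnx.state(model, nnx.Param) extracts only the trainable params, and nnx.split(model, nnx.Param, ...) splits them from everything else. The filters given to split must cover every variable, hence the trailing ....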
  • Linen params dict: separates collections (params, batch_stats, cache) — same goal, different layer.

Translating drops the wrappers but preserves the arrays. If you needed BatchNorm statistics, on the Linen side you’d put them in a separate top-level "batch_stats" collection next to "params" ({"params": ..., "batch_stats": ...}). On the nnx side they are nnx.BatchStat variables, extracted with nnx.state(model, nnx.BatchStat).

Why care about this for Linear/Dense?

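For a single Linear, the translation is trivial. For a full Transformer with LayerNorms, attention sublayers, and FFNs, it is systematic. The same recipe applies at every leaf, provided the submodule names agree on both sides. nnx keys come from attribute names; Linen keys come from name= arguments or from auto-generated names such as Dense_0: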

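from collections.abc import Mapping
def state_to_linen(state):
    # Mapping nodes (nnx.State) stay dicts; each Variable leaf becomes its .value.
    def unwrap(node):
        if isinstance(node, Mapping):
            return {str(k): unwrap(v) for k, v in node.items()}
        return node.value
    return {"params": unwrap(state)}  # Linen expects a collection-keyed dict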

Doing it once for Linear makes the recipe concrete.

Common pitfalls

  • Forgetting .value on the wrapper. state["kernel"] is a Variable, not the array; you can’t pass it directly to nn.Dense.apply without unwrapping.
  • Forgetting the "params" outer key. Linen’s apply expects a collection-keyed dict. Without "params" the Dense will complain about a missing collection.
  • Trying to keep variable metadata. Linen’s plain arrays don’t carry it. If you need shardings or other annotations after translation, re-attach them separately on the Linen side, e.g. via flax.linen.with_partitioning, which wraps initializer outputs in nn.Partitioned boxes.
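  • Different layouts for other layer types. nnx.Linear and nn.Dense agree on an (in, out) kernel, so a plain unwrap works. Layers built on DenseGeneral-style projections use multi-axis kernels; for example, the query/key/value kernels of nn.MultiHeadDotProductAttention are shaped (in, num_heads, head_dim). Unwrapping is only valid when both sides store the same names and shapes. A checkpoint with a different layout needs an explicit reshape or transpose, not a simple unwrap.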

Problem

Write state_translation(seed, x, features):

  1. Build nnx_linear = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed))). Compute nnx_out = nnx_linear(x).
  2. Split: graphdef, state = nnx.split(nnx_linear).
  3. Translate to Linen format:
    linen_params = {
        "params": {
            "kernel": state["kernel"].value,
            "bias":   state["bias"].value,
        }
    }
  4. Apply via Linen: linen_dense = nn.Dense(features=int(features)); linen_out = linen_dense.apply(linen_params, x).
  5. Return jnp.array([float(nnx_out.sum()), float(linen_out.sum())]).

Sums must be equal.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).

Output: length-2 array [s, s].
