NNX Bridge: State Translation NNX → Linen
Why this matters
nnx and Linen aren’t separate worlds: they’re two containers around
the same JAX arrays. nnx.split(model) returns a graphdef plus a state
pytree of nnx.Variable wrappers; Linen’s params is a nested dict of
plain arrays. Translating between them is rewrapping, not conversion.
Once you can translate state in either direction, you can:
- Save an nnx model in Linen-compatible format for downstream tools.
- Load a Linen checkpoint into an nnx model.
- Run a parallel sanity check (apply the same numerical params via both frameworks; outputs must match).
The translation
Build an nnx.Linear. Use nnx.split to get its state:
nnx_linear = nnx.Linear(in_f, out_f, rngs=rngs)
graphdef, state = nnx.split(nnx_linear)
# state['kernel']: nnx.Param(value=(in_f, out_f) array)
# state['bias']: nnx.Param(value=(out_f,) array)
The state is an nnx.State pytree, addressable like a dict, where
each leaf is an nnx.Variable (nnx.Param for trainable parameters).
.value gives the underlying jax.Array.
Linen wants:
linen_params = {
    "params": {
        "kernel": <jax.Array>,
        "bias": <jax.Array>,
    }
}
Translation is direct:
linen_params = {
    "params": {
        "kernel": state["kernel"].value,
        "bias": state["bias"].value,
    }
}
Now nn.Dense(out_f).apply(linen_params, x) produces the same
output as nnx_linear(x) — they’re using the same arrays.
Why “isomorphic” not “identical”
state["kernel"] is an nnx.Param wrapper that records metadata
(the variable type, sharding hints, dtype, attributes). Linen’s
params["kernel"] is just the array.
The wrappers exist for different reasons:
- nnx.Variable: marks the array’s role (trainable param, BN stat, RNG key, KV cache). nnx APIs use this to filter, e.g. nnx.split(model, nnx.Param, ...) separates the trainable variables from everything else.
- Linen params dict: separates collections (params, batch_stats, cache). Same goal, different layer.
Translating drops the wrappers but preserves the arrays. If you
needed BN statistics, you’d put them in a separate "batch_stats"
collection next to "params" on the Linen side; on the nnx side they’d
be nnx.BatchStat variables, selected with
nnx.split(model, nnx.BatchStat, ...).
Why care about this for Linear/Dense?
For a single Linear, the translation is trivial. For a full Transformer with LayerNorms, attention sublayers, FFNs, the translation is systematic — the same recipe at every leaf:
def state_to_linen(state):
    # state is a nested State; recurse, replacing each nnx.Variable
    # with .value and wrapping the top-level dict under "params".
    ...
Doing it once for Linear makes the recipe concrete.
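One way the recursion could look is sketched below. To stay runnable without Flax, it uses a hypothetical FakeParam class that only mimics the .value attribute of an nnx.Variable, and a plain nested dict in place of nnx.State. It also assumes every leaf is a trainable parameter, so everything goes under "params". Whether it works unchanged on a real nnx.State depends on that object behaving like a mapping, which is an assumption here rather than something this sketch verifies.

```python
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeParam:
    """Hypothetical stand-in for nnx.Param: only mimics the .value attribute."""
    value: Any


def unwrap(node):
    # Nested mappings recurse; wrapper leaves are replaced by their payload.
    if isinstance(node, Mapping):
        return {key: unwrap(child) for key, child in node.items()}
    return node.value if hasattr(node, "value") else node


def state_to_linen(state):
    # Assumes all leaves are trainable, so one "params" collection suffices.
    return {"params": unwrap(state)}


# Toy nested state shaped like a tiny Transformer block (illustrative names).
toy_state = {
    "attn": {"query": {"kernel": FakeParam([[1.0, 2.0]])}},
    "ffn": {"kernel": FakeParam([[3.0], [4.0]]), "bias": FakeParam([0.0])},
}
linen_like = state_to_linen(toy_state)
```

The nesting and key names pass through untouched; only the leaves change, which is why the same recipe applies at every depth.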
Common pitfalls
- Forgetting .value on the wrapper. state["kernel"] is a Variable, not the array; you can’t pass it directly to nn.Dense.apply without unwrapping.
- Forgetting the "params" outer key. Linen’s apply expects a collection-keyed dict. Without "params" the Dense will complain about a missing collection.
- Trying to keep variable metadata. Plain Linen arrays don’t carry it. If you need shardings or types post-translation, you re-attach them separately via flax.linen.with_partitioning etc.
- Different layout convention for some layer types. Both nnx.Linear and nn.Dense use kernel (in, out). Some other layers (e.g., nn.MultiHeadDotProductAttention) use multi-axis kernels; check their shapes and names on both sides, since those may not translate via a simple unwrap.
Problem
Write state_translation(seed, x, features):
- Build nnx_linear = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed))). Compute nnx_out = nnx_linear(x).
- Split: graphdef, state = nnx.split(nnx_linear).
- Translate to Linen format: linen_params = {"params": {"kernel": state["kernel"].value, "bias": state["bias"].value}}.
- Apply via Linen: linen_dense = nn.Dense(features=int(features)); linen_out = linen_dense.apply(linen_params, x).
- Return jnp.array([float(nnx_out.sum()), float(linen_out.sum())]).
Sums must be equal.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: length-2 array [s, s].