NNX Port: Linen Dense → NNX Linear
Why this matters
The most common cross-framework operation in real life isn’t fancy bridge wrappers — it’s porting weights. You have a Linen checkpoint (or a fresh init) and you want the same numerical model in nnx. Or vice versa.
The good news: a Dense/Linear layer is the same math in both
frameworks. The kernel and bias arrays have the same shape and the
same role. Porting is two assignments.
Once you’ve done it for Dense, the same pattern scales: each layer
type has its own slot names (kernel/bias, scale/bias,
embedding, etc.), and you copy by attribute. No serialization, no
converter scripts, no opinions.
Layout convention (the only thing to remember)
Both Flax frameworks use the same convention for fully-connected layers:
- kernel: (in_features, out_features)
- bias: (out_features,)
The constructors differ: nnx.Linear(in_features=I, out_features=O, ...) takes explicit dims at construction, while nn.Dense(features=O) takes only O and infers I from x at init.
But the arrays have identical layouts. Once both modules are materialized, the copy is one-to-one.
The two-step port
- Build the Linen Dense, init it on x, get its params dict:

    linen_dense = nn.Dense(features=F)
    linen_params = linen_dense.init(jax.random.PRNGKey(seed), x)
    # linen_params = {"params": {"kernel": (in, F), "bias": (F,)}}

- Build the nnx Linear with matching dims, then overwrite its Variables:

    nnx_linear = nnx.Linear(in_features=in_f, out_features=F, rngs=rngs)
    nnx_linear.kernel.value = linen_params["params"]["kernel"]
    nnx_linear.bias.value = linen_params["params"]["bias"]
After that, nnx_linear(x) is bit-for-bit equal to
linen_dense.apply(linen_params, x).
Why use a different seed (seed + 1) for the nnx init?
To make a point: it doesn’t matter what initialization the nnx Linear
starts with — we’re going to overwrite both kernel.value and
bias.value. Using a different seed makes it visible in tracing /
debugging that the nnx module’s initial weights are not what’s
being used.
In production you’d typically not bother with the nnx init at all —
use jax.eval_shape and Orbax to allocate, or just accept the brief
waste of a fresh init.
Common pitfalls
- Forgetting ["params"] in the Linen dict. Linen wraps params in a top-level "params" collection (alongside other collections like "batch_stats" for BatchNorm).
- Assigning a NumPy array. It works, but the value will be converted to a JAX array on the next forward. Cleaner to keep everything as jax.Array throughout.
- Trying nnx_linear.kernel = linen_params[...]. That replaces the nnx.Param wrapper with a plain array; the next forward will fail because nnx.split expects Variables. Always assign to .value.
- Mismatched shapes. If in_features was inferred from x in Linen but you mistyped the nnx in_features, the assignment succeeds (no shape check on .value =) but the forward errors. Verify dims first.
Problem
Write port_dense(seed, x, features):
- linen_dense = nn.Dense(features=int(features)). Init it: linen_params = linen_dense.init(jax.random.PRNGKey(int(seed)), x). Compute linen_out = linen_dense.apply(linen_params, x).
- nnx_linear = nnx.Linear(in_features=int(x.shape[-1]), out_features=int(features), rngs=nnx.Rngs(int(seed) + 1)). Replace its Variables with the Linen values:
  - nnx_linear.kernel.value = linen_params["params"]["kernel"].
  - nnx_linear.bias.value = linen_params["params"]["bias"].
- Compute nnx_out = nnx_linear(x).
- Return jnp.array([float(nnx_out.sum()), float(linen_out.sum())]).
Both sums must be identical.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: length-2 array [s, s] (sums equal).