NNX Coexistence: Parity With Linen
Why this matters
The whole sequence of bridge problems comes down to one claim: nnx and Linen aren’t separate worlds. They’re two front-ends to the same JAX arrays. This problem makes that explicit by running the same params through both and checking that the outputs are bit-for-bit identical.
Once you’ve verified parity, every other operation — porting, sharing, mixing, training — is just paperwork. The math agrees.
What “parity” means here
Build an nnx.Linear. Its kernel.value and bias.value are
plain JAX arrays. Now build an nn.Dense and pass the same
arrays as its params. Apply both to the same input.
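```python
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx

in_f, out_f, seed = 4, 3, 0               # illustrative sizes and seed
x = jnp.arange(in_f, dtype=jnp.float32)  # a 1-D input of length in_f

nnx_linear = nnx.Linear(in_f, out_f, rngs=nnx.Rngs(seed))
linen_dense = nn.Dense(features=out_f)   # stateless; params are passed to apply()
linen_params = {
    "params": {
        "kernel": nnx_linear.kernel.value,  # shape (in_f, out_f)
        "bias": nnx_linear.bias.value,      # shape (out_f,)
    }
}
nnx_out = nnx_linear(x)
linen_out = linen_dense.apply(linen_params, x)
```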
Both compute x @ kernel + bias over identical kernel and bias
arrays. Outputs are equal — jnp.max(jnp.abs(nnx_out - linen_out)) == 0.0.
What this proves
- Initialization drops out. Once both modules read their params from the same source, how each framework would have initialized them no longer matters; both produce the same result.
- No hidden math differences. The dot product, the bias add, the dtype handling — all the same.
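- Same parameter arrays. linen_params holds references to the arrays owned by nnx_linear; building it copies nothing. The two outputs are separate arrays, but their values are bitwise identical, not merely numerically close.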
The mental model “nnx is the new framework; Linen is the old
framework; they don’t talk to each other” is wrong. The right
model: “nnx and Linen are two ergonomic shells around the same
jax.Array math; they talk by sharing arrays.”
Where coexistence matters in production
- Test parity during migration. As you port modules from Linen to nnx, run the same input through both. If outputs diverge, you’ve made a porting error — caught immediately.
- A/B against a Linen reference. A research baseline ships in Linen; your new code is in nnx. Coexistence tests are the regression suite that catches subtle bugs.
- Mixed-framework debugging. When something goes wrong, you can isolate: is it the nnx module? The Linen module? Or the data? Parity testing is the diagnostic.
Common pitfalls
- Different RNG between nnx and Linen. nnx.Linear(..., rngs=nnx.Rngs(seed)) and nn.Dense(...).init(jax.random.PRNGKey(seed), x) produce different kernels even with the same integer seed, because the two frameworks derive each parameter's key from the seed differently. If you don't share params explicitly, expect different outputs.
- Forgetting the outer 'params' key. Linen’s apply needs {"params": {...}}, not just {...}.
- Subtle dtype mismatches. If one side promotes to float32 and the other stays float16, the diff won’t be exactly 0.0. For this problem we use float32 throughout, so that’s not an issue.
- use_bias mismatch. Both nnx.Linear and nn.Dense default to use_bias=True. If you override it on one side, override it on the other too, so both sides agree.
Problem
Write coexistence_parity(seed, x, features):
- Build nnx_linear = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed))).
- Build linen_dense = nn.Dense(features=int(features)).
- Construct a linen_params dict pointing at nnx_linear.kernel.value and nnx_linear.bias.value.
- Compute nnx_out = nnx_linear(x) and linen_out = linen_dense.apply(linen_params, x).
- Compute abs_diff = jnp.max(jnp.abs(nnx_out - linen_out)).
- Return jnp.array([float(nnx_out.sum()), float(linen_out.sum()), float(abs_diff)]).
Expected: first two values equal, third is 0.0.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: length-3 array [s, s, 0.0].