
NNX Coexistence: Parity With Linen

Why this matters

The whole sequence of bridge problems comes down to one claim: nnx and Linen aren't separate worlds. They're two front-ends to the same JAX arrays. This problem makes that explicit by running the same params through both and checking that the output difference is exactly zero, bit for bit.

Once you've verified parity, every other operation (porting, sharing, mixing, training) is bookkeeping: moving the same arrays between containers. The math already agrees.

What “parity” means here

Build an nnx.Linear. Its kernel.value and bias.value are plain JAX arrays. Now build an nn.Dense and pass those same arrays in as its params. Apply both to the same input.

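import jax.numpy as jnp
from flax import linen as nn
from flax import nnx

in_f, out_f, seed = 4, 3, 0
x = jnp.ones((in_f,), dtype=jnp.float32)  # any float32 input of shape (in_f,)

nnx_linear = nnx.Linear(in_f, out_f, rngs=nnx.Rngs(seed))
linen_dense = nn.Dense(features=out_f)

# Linen expects the arrays nested under the "params" collection.
linen_params = {
    "params": {
        "kernel": nnx_linear.kernel.value,  # shape (in_f, out_f)
        "bias":   nnx_linear.bias.value,    # shape (out_f,)
    }
}

nnx_out   = nnx_linear(x)
linen_out = linen_dense.apply(linen_params, x)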

Both modules compute x @ kernel + bias over the identical kernel and bias arrays, so the outputs match exactly: jnp.max(jnp.abs(nnx_out - linen_out)) == 0.0.

What this shows

  • Initialization is out of the picture. Once both modules read the same params, any difference in how each framework would have initialized them stops mattering.
  • No hidden math differences. The dot product, the bias add, and the dtype handling all match.
  • Shared parameters, not copies. Linen reads the very jax.Array objects that nnx holds, so nothing is converted along the way, and the outputs are bitwise identical rather than merely close.

The mental model “nnx is the new framework; Linen is the old framework; they don’t talk to each other” is wrong. The right model: “nnx and Linen are two ergonomic shells around the same jax.Array math; they talk by sharing arrays.”

Where coexistence matters in production

  • Test parity during migration. As you port modules from Linen to nnx, share the params and run the same input through both. If the outputs diverge, the port has a bug, and you catch it immediately.
  • A/B against a Linen reference. A research baseline ships in Linen; your new code is in nnx. Coexistence tests are the regression suite that catches subtle bugs.
  • Mixed-framework debugging. When something goes wrong, you can isolate: is it the nnx module? The Linen module? Or the data? Parity testing is the diagnostic.

Common pitfalls

  • Different RNG plumbing between nnx and Linen. nnx.Rngs(seed) and a Linen init with jax.random.PRNGKey(seed) derive their parameter keys differently, so the same seed generally yields different initial weights. If you don't share params explicitly, expect different outputs.
  • Forgetting the outer 'params' key. Linen's apply needs {"params": {...}}, not just {...}.
  • Subtle dtype mismatches. If one side promotes to float32 and the other stays float16, the diff won’t be exactly 0.0. For this problem we use float32 throughout, so that’s not an issue.
  • use_bias must agree. Both nnx.Linear and nn.Dense default to use_bias=True. If you turn the bias off on one side, turn it off on the other too and drop the "bias" entry from linen_params.

Problem

Write coexistence_parity(seed, x, features):

  1. Build nnx_linear = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed))).
  2. Build linen_dense = nn.Dense(features=int(features)).
  3. Construct a linen_params dict whose kernel and bias entries are nnx_linear.kernel.value and nnx_linear.bias.value, nested under "params".
  4. Compute nnx_out = nnx_linear(x) and linen_out = linen_dense.apply(linen_params, x).
  5. Compute abs_diff = jnp.max(jnp.abs(nnx_out - linen_out)).
  6. Return jnp.array([float(nnx_out.sum()), float(linen_out.sum()), float(abs_diff)]).

Expected: first two values equal, third is 0.0.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).

Output: length-3 array [s, s, 0.0], where s is the sum of the layer's output.
