
NNX Port: Linen Dense → NNX Linear

Why this matters

The most common cross-framework operation in real life isn’t fancy bridge wrappers — it’s porting weights. You have a Linen checkpoint (or a fresh init) and you want the same numerical model in nnx. Or vice versa.

The good news: a Dense/Linear layer is the same math in both frameworks. The kernel and bias arrays have the same shape and the same role. Porting is two assignments.

Once you’ve done it for Dense, the same pattern scales: each layer type has its own slot names (kernel/bias, scale/bias, embedding, etc.), and you copy by attribute. No serialization, no converter scripts, no opinions.

Layout convention (the only thing to remember)

Both Flax frameworks use the same convention for fully-connected layers:

  • kernel: (in_features, out_features)
  • bias: (out_features,)

The constructors differ: nnx.Linear(in_features=I, out_features=O, ...) takes explicit dims at construction, while nn.Dense(features=O) takes O only and infers I from x at init.

But the array layouts are identical. Once both modules are materialized, the copy is one-to-one.

The two-step port

  1. Build the Linen Dense, init it on x, get its params dict:
    linen_dense  = nn.Dense(features=F)
    linen_params = linen_dense.init(jax.random.PRNGKey(seed), x)
    # linen_params = {"params": {"kernel": (in, F), "bias": (F,)}}
  2. Build the nnx Linear with matching dims, then overwrite its Variables:
    nnx_linear = nnx.Linear(in_features=in_f, out_features=F, rngs=rngs)
    nnx_linear.kernel.value = linen_params["params"]["kernel"]
    nnx_linear.bias.value   = linen_params["params"]["bias"]

After that, nnx_linear(x) is bit-for-bit equal to linen_dense.apply(linen_params, x).

Why use seed + 1 for the nnx init?

To make a point: it doesn’t matter what initialization the nnx Linear starts with — we’re going to overwrite both kernel.value and bias.value. Using a different seed makes it visible in tracing / debugging that the nnx module’s initial weights are not what’s being used.

In production you’d typically not bother with the nnx init at all — use jax.eval_shape and Orbax to allocate, or just accept the brief waste of a fresh init.

Common pitfalls

  • Forgetting ["params"] in the Linen dict. Linen wraps params in a top-level "params" collection (alongside other collections like "batch_stats" for BatchNorm).
  • Assigning a NumPy array. It works, but the value will be converted to a jax.Array on the next forward pass. Cleaner to keep everything as jax.Array throughout.
  • Trying nnx_linear.kernel = linen_params[...]. That replaces the nnx.Param wrapper with a plain array, so anything that expects a Variable in that slot (the forward pass reading .value, or nnx.split collecting state) can break. Always assign to .value.
  • Mismatched shapes. If in_features was inferred from x in Linen but you mistyped the nnx in_features, the assignment succeeds (no shape check on .value =) but the forward pass errors. Verify dims first.

Problem

Write port_dense(seed, x, features):

  1. linen_dense = nn.Dense(features=int(features)). Init it: linen_params = linen_dense.init(jax.random.PRNGKey(int(seed)), x). Compute linen_out = linen_dense.apply(linen_params, x).
  2. nnx_linear = nnx.Linear(in_features=int(x.shape[-1]), out_features=int(features), rngs=nnx.Rngs(int(seed) + 1)). Replace its Variables with the Linen values:
    • nnx_linear.kernel.value = linen_params["params"]["kernel"].
    • nnx_linear.bias.value = linen_params["params"]["bias"].
  3. Compute nnx_out = nnx_linear(x).
  4. Return jnp.array([float(nnx_out.sum()), float(linen_out.sum())]).

Both sums must be identical.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).

Output: length-2 array [s, s] (sums equal).

