
NNX Bridge: Call NNX From Linen

Why this matters

The reverse direction of the previous problem: you have an existing Linen training pipeline (with model.init, model.apply, optax transforms over a params pytree) and you want to drop in a new nnx layer without rewriting the whole thing.

flax.nnx.bridge.to_linen(NNXClass) returns a Linen nn.Module whose forward delegates to a fresh NNXClass(rngs=...) instance. From the outside it looks like any other Linen layer: model.init(key, x) returns a variables dict (with the weights under 'params'), and model.apply(variables, x) runs the forward pass.

The wrapper is the second half of incremental migration: write the new code as nnx, expose it to the legacy Linen pipeline as a Linen module. Both halves live happily until the migration completes.

API: bridge.to_linen

import jax
import jax.numpy as jnp
import flax.linen as nn  # the Linen API the wrapper plugs into
from flax import nnx
from flax.nnx import bridge

class MyNNX(nnx.Module):
    def __init__(self, rngs):
        self.linear = nnx.Linear(in_features=4, out_features=8, rngs=rngs)
    def __call__(self, x):
        return self.linear(x)

linen_wrapped = bridge.to_linen(MyNNX)
# `linen_wrapped` is a Linen Module, so it has init/apply.
key    = jax.random.PRNGKey(0)
# init returns the full Linen variables dict; the weights live under 'params'.
params = linen_wrapped.init(key, jnp.ones((4,)))
y      = linen_wrapped.apply(params, jnp.array([1.0, 2.0, 3.0, 4.0]))

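Under the hood, during init the wrapper draws keys from Linen's RNG streams, instantiates MyNNX(rngs=nnx.Rngs(...)) with them, splits the resulting model into (graphdef, state), and registers the state's arrays as Linen variables (nnx.Param values land in the 'params' collection). On apply, it rebuilds the nnx module, loads the values from the variables you pass in, and calls forward. The exact mechanics differ between Flax releases, but the (graphdef, state) round trip is the core idea.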

The seed propagation is invisible: Linen’s init key gets routed into nnx.Rngs for you. You don’t have to thread anything manually.

The constructor signature contract

By default, to_linen(NNXClass) calls NNXClass(rngs=...), so the nnx class's __init__ must accept a rngs parameter, and every other required constructor argument has to be supplied some other way. to_linen(NNXClass, *args, **kwargs) forwards extra positional and keyword arguments to the constructor, so config like out_features can be passed as to_linen(NNXClass, out_features=8). Alternatively, close over the config in a local class definition, which is what the worked example below does for simplicity.

Worked example

def nnx_inside_linen(seed, x, features):
    F = int(features)
    in_features = int(x.shape[-1])

    class MyNNX(nnx.Module):
        def __init__(self, rngs):
            self.linear = nnx.Linear(
                in_features=in_features, out_features=F, rngs=rngs
            )
        def __call__(self, x):
            return self.linear(x)

    linen_wrapped = bridge.to_linen(MyNNX)
    key    = jax.random.PRNGKey(int(seed))
    params = linen_wrapped.init(key, x)
    return linen_wrapped.apply(params, x)

Closing over in_features and F keeps the wrapped class's constructor rngs-only, so to_linen can invoke it without any extra args.

Why bother?

  • You inherited a Linen pipeline (Flaxformer, T5X, MaxText, EasyLM…) and want one new layer in nnx without converting the rest.
  • You’re sharing nnx-authored code with a research team still on Linen — to_linen is the export adapter.
  • You want to reuse Linen-based tooling (training loops, checkpointing, or sharding helpers written against init/apply and a params pytree) with an nnx-defined module.

Common pitfalls

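  • Forgetting that the nnx class's __init__ must accept rngs. to_linen calls it with rngs=. If your __init__ is (self, features, rngs), you have to supply features separately, e.g. to_linen(MyNNX, features) or to_linen(MyNNX, features=8).
  • Trying to read the nnx state straight out of the Linen variables as if it were an nnx model. The arrays are stored as Linen variables keyed by the nnx attribute path (the weights under 'params'), possibly wrapped in metadata boxes, and some Flax versions also add an extra 'nnx' collection for static metadata; the structure differs from a hand-built nnx model's state.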
  • Different RNG semantics. Linen splits the init key under the hood; the same numerical seed may produce different params than a direct nnx.Linear(... rngs=nnx.Rngs(seed)) would. That’s expected — both are valid, just routed differently.

Problem

Write nnx_inside_linen(seed, x, features):

  1. Inside the function, define a local MyNNX(nnx.Module) whose __init__(self, rngs) builds an nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=rngs) and whose __call__(x) returns self.linear(x).
  2. Wrap it: linen_wrapped = bridge.to_linen(MyNNX).
  3. Run the standard Linen flow: params = linen_wrapped.init(key, x) with key = jax.random.PRNGKey(int(seed)); return linen_wrapped.apply(params, x).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).

Output: 1-D array of length features.

