NNX Bridge: Call NNX From Linen
Why this matters
The reverse direction of the previous problem: you have an existing
Linen training pipeline (with model.init, model.apply, optax
transforms over a params pytree) and you want to drop in a new nnx
layer without rewriting the whole thing.
flax.nnx.bridge.to_linen(NNXClass) returns a Linen nn.Module whose
forward delegates to a fresh NNXClass(rngs=...) instance. From the
outside it looks like any other Linen layer: model.init(key, x)
returns a params dict; model.apply(params, x) runs forward.
The wrapper is the second half of incremental migration: write the new code as nnx, expose it to the legacy Linen pipeline as a Linen module. Both halves live happily until the migration completes.
API: bridge.to_linen
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

class MyNNX(nnx.Module):
    def __init__(self, rngs):
        self.linear = nnx.Linear(in_features=4, out_features=8, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

linen_wrapped = bridge.to_linen(MyNNX)
# `linen_wrapped` is a Linen Module, so it has init/apply.
key = jax.random.PRNGKey(0)
params = linen_wrapped.init(key, jnp.ones((4,)))
y = linen_wrapped.apply(params, jnp.array([1.0, 2.0, 3.0, 4.0]))
Under the hood, to_linen builds a Linen Module that, on first trace,
instantiates MyNNX(rngs=nnx.Rngs(<linen_init_key>)), splits the
resulting model into (graphdef, state), and registers the state’s
arrays as Linen params. On apply, it merges (graphdef, state_from_params) and calls forward.
The seed propagation is invisible: Linen’s init key gets routed into
nnx.Rngs for you. You don’t have to thread anything manually.
The constructor signature contract
to_linen(NNXClass) calls NNXClass(rngs=...) — so the nnx class’s
__init__ must take a rngs parameter (and only rngs if you call
to_linen with no extra args). To pass other config like
out_features, define them as constructor args and pass them via
to_linen(NNXClass, out_features=8) keyword args, OR close over them
in a local class definition (which is what we do here for simplicity).
Worked example
def nnx_inside_linen(seed, x, features):
    F = int(features)
    in_features = int(x.shape[-1])

    class MyNNX(nnx.Module):
        def __init__(self, rngs):
            self.linear = nnx.Linear(
                in_features=in_features, out_features=F, rngs=rngs
            )

        def __call__(self, x):
            return self.linear(x)

    linen_wrapped = bridge.to_linen(MyNNX)
    key = jax.random.PRNGKey(int(seed))
    params = linen_wrapped.init(key, x)
    return linen_wrapped.apply(params, x)
Closing over in_features and F keeps the wrapped class to a
(rngs)-only constructor — to_linen invokes it with no extra args.
Why bother?
- You inherited a Linen pipeline (Flaxformer, T5X, MaxText, EasyLM…) and want one new layer in nnx without converting the rest.
- You’re sharing nnx-authored code with a research team still on Linen; to_linen is the export adapter.
- You want to use Linen-only utilities (e.g., particular optax integrations, sharding helpers) on an nnx-defined module.
Common pitfalls
- Forgetting that the nnx class’s __init__ must accept rngs. to_linen calls it with rngs=. If your __init__ only takes (self, features, rngs) you have to pass features separately.
- Trying to access nnx state directly from params. The state is flattened into Linen params['nnx'] (or similar); structure differs from a hand-built nnx model.
- Different RNG semantics. Linen splits the init key under the hood; the same numerical seed may produce different params than a direct nnx.Linear(... rngs=nnx.Rngs(seed)) would. That’s expected: both are valid, just routed differently.
Problem
Write nnx_inside_linen(seed, x, features):
- Inside the function, define a local MyNNX(nnx.Module) whose __init__(self, rngs) builds an nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=rngs) and whose __call__(x) returns self.linear(x).
- Wrap it: linen_wrapped = bridge.to_linen(MyNNX).
- Run the standard Linen flow: params = linen_wrapped.init(key, x) with key = jax.random.PRNGKey(int(seed)); return linen_wrapped.apply(params, x).
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: 1-D array of length features.