NNX Bridge: Call Linen From NNX
Why this matters
Real codebases don’t migrate overnight. You’ll have a fresh nnx model
that wants to use a flax.linen layer that already exists — maybe a
custom nn.Module shipped by a research team, or the official Linen
MultiHeadDotProductAttention you don’t feel like rewriting yet.
Flax 0.11+ ships flax.nnx.bridge for exactly this. bridge.ToNNX(linen_module)
wraps a Linen nn.Module so it acts like an nnx.Module. You can then
drop it into an nnx.Module as an attribute and call it like any other
submodule.
The pattern is the foundation of incremental migration: keep your old
Linen layers, build new code in nnx, glue them with bridge.ToNNX.
API: bridge.ToNNX
```python
import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

linen_dense = nn.Dense(features=4)  # plain Linen module
wrapped = bridge.ToNNX(linen_dense, rngs=nnx.Rngs(0))
# `wrapped` is now an nnx.Module; its params will live as nnx Variables
# once bridge.lazy_init has traced it with a sample input.
```
The wrapper hides Linen’s init/apply ceremony. But since Linen
layers are lazy (they need a sample input to materialize parameters),
you must call bridge.lazy_init(model, x) once before the first real
forward pass. After that, the wrapper behaves like a normal nnx module:
state is reachable via nnx.split, gradients flow, you can call it
inside nnx.jit, etc.
Worked sketch
```python
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

class HybridModel(nnx.Module):
    def __init__(self, features: int, rngs: nnx.Rngs):
        # The Linen module is wrapped and stored as an nnx attribute.
        self.linen_dense = bridge.ToNNX(nn.Dense(features=features), rngs=rngs)

    def __call__(self, x):
        return self.linen_dense(x)

model = HybridModel(features=4, rngs=nnx.Rngs(0))
bridge.lazy_init(model, jnp.ones((3,)))  # materialize Linen params
y = model(jnp.array([1.0, 2.0, 3.0]))
```
The inner Linen nn.Dense declares its kernel and bias via self.param at
init time. lazy_init runs Linen's initialization once on the sample input
and stores the resulting params as nnx Variables on the wrapper.
Why lazy_init?
Linen modules infer parameter shapes from their inputs: nn.Dense(features=4)
doesn’t know the input dimension until you trace it on real data. nnx layers
such as nnx.Linear, by contrast, take both in_features and out_features at
construction. ToNNX closes this gap by deferring init until you call
lazy_init. Whether a plain first forward pass also triggers init depends on
the Flax version, so call lazy_init explicitly.
Where this fits in production
- Reusing a Linen nn.MultiHeadDotProductAttention inside a new nnx Transformer.
- Calling a third-party Linen layer (e.g., a tokenizer-side embedding from Flaxformer) without rewriting it.
- Gradual port: convert leaf modules to nnx, leave higher-level Linen orchestration for later.
Common pitfalls
- Forgetting lazy_init. Without it, the wrapper has no params yet, and depending on the Flax version the first call either raises or initializes on the fly. Always call lazy_init(model, x) before the first forward pass.
- Passing rngs=nnx.Rngs(seed) twice. The wrapper already holds the rngs, so don’t also feed Linen-style rngs dicts at apply time.
- Calling nnx.split before lazy_init. The Linen params don’t exist yet, so split returns a state that is missing them.
- Re-running lazy_init. Calling it again on a fresh sample input reinitializes the params, and you lose anything you trained.
Problem
Write linen_inside_nnx(seed, x, features):
- Define MyNNXWithLinen(nnx.Module) whose __init__(features, rngs) stores self.linen_dense = bridge.ToNNX(nn.Dense(features=features), rngs=rngs).
- __call__(x) returns self.linen_dense(x).
- In the entry function: cast features to int, build nnx.Rngs(int(seed)), instantiate the model, call bridge.lazy_init(model, x), then return model(x).
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: 1-D array of length features.