
NNX Bridge: Call Linen From NNX

Why this matters

Real codebases don’t migrate overnight. You’ll have a fresh nnx model that wants to use a flax.linen layer that already exists — maybe a custom nn.Module shipped by a research team, or the official Linen MultiHeadDotProductAttention you don’t feel like rewriting yet.

Flax 0.11+ ships flax.nnx.bridge for exactly this. bridge.ToNNX(linen_module) wraps a Linen nn.Module so it acts like an nnx.Module. You can then drop it into an nnx.Module as an attribute and call it like any other submodule.

The pattern is the foundation of incremental migration: keep your old Linen layers, build new code in nnx, glue them with bridge.ToNNX.

API: bridge.ToNNX

import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

linen_dense = nn.Dense(features=4)              # plain Linen module
wrapped = bridge.ToNNX(linen_dense, rngs=nnx.Rngs(0))
# `wrapped` is now an nnx.Module; its params will appear as nnx Variables once lazy_init has run.

The wrapper hides Linen’s init/apply ceremony. But since Linen layers are lazy (they need a sample input to materialize parameters), you must call bridge.lazy_init(model, x) once before the first real forward pass. After that, the wrapper behaves like a normal nnx module: state is reachable via nnx.split, gradients flow, you can call it inside nnx.jit, etc.

Worked sketch

import jax.numpy as jnp
import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

class HybridModel(nnx.Module):
    def __init__(self, features, rngs):
        # The Linen module is wrapped and stored as an nnx attribute.
        self.linen_dense = bridge.ToNNX(nn.Dense(features=features), rngs=rngs)

    def __call__(self, x):
        return self.linen_dense(x)

model = HybridModel(features=4, rngs=nnx.Rngs(0))
bridge.lazy_init(model, jnp.ones((3,)))   # materialize Linen params
y = model(jnp.array([1.0, 2.0, 3.0]))     # y.shape == (4,)

The inner Linen nn.Dense declares its kernel/bias via self.param at init time. lazy_init runs the trace once with the sample input and captures the resulting params dict into nnx Variables on the wrapper.

Why lazy_init?

Linen modules infer parameter shapes from their inputs: nn.Dense(features=4) doesn’t know its input dimension until it is traced on real data. Native nnx layers such as nnx.Linear, by contrast, take both in_features and out_features at construction and create their params immediately. ToNNX closes this gap by deferring init until you call lazy_init with a sample input.

Where this fits in production

  • Reusing a Linen nn.MultiHeadDotProductAttention inside a new nnx Transformer.
  • Calling a third-party Linen layer (e.g., a tokenizer-side embedding from Flaxformer) without rewriting it.
  • Gradual port: convert leaf modules to nnx, leave higher-level Linen orchestration for later.

Common pitfalls

  • Forgetting lazy_init. Without it, the wrapper has no params yet, and the first call typically fails instead of initializing them for you. Always call bridge.lazy_init(model, x) before the first forward.
  • Passing rngs=nnx.Rngs(seed) twice. The wrapper holds the rngs; don’t also try to feed Linen-style rngs dicts at apply time.
  • Calling nnx.split before lazy_init. State doesn’t exist yet — split would return an empty/incomplete pytree.
  • Re-lazy_init-ing. Re-running it on a fresh sample input reinitializes — you lose any trained params.

Problem

Write linen_inside_nnx(seed, x, features):

  1. Define MyNNXWithLinen(nnx.Module) whose __init__(features, rngs) stores self.linen_dense = bridge.ToNNX(nn.Dense(features=features), rngs=rngs).
  2. __call__(x) returns self.linen_dense(x).
  3. In the entry function: cast features to int, build nnx.Rngs(int(seed)), instantiate the model, call bridge.lazy_init(model, x), then return model(x).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).

Output: 1-D array of length features.

Hints

flax nnx bridge interop linen
