NNX Bridge: Train Mixed Linen+NNX Model
Why this matters
During an incremental Linen → nnx migration you’ll have hybrid models: an nnx encoder feeding into a Linen head, or a Linen embed layer feeding into an nnx Transformer. The training step must update both halves in a single backward pass.
The pedagogical trick: we don't need bridge wrappers for this.
Gradients flow through plain JAX, so as long as you can express
the forward pass as a pure function of two parameter pytrees
(one nnx-derived, one Linen-derived), jax.grad differentiates
through both and you can apply SGD to both.
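A minimal pure-JAX sketch of that fact, with no Flax involved: two hypothetical parameter dicts with different nesting stand in for the nnx-derived and Linen-derived pytrees, and a single jax.grad call returns one gradient pytree for each.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins: a flat dict ("nnx-like") and a nested dict ("Linen-like").
params_a = {"w": jnp.array([1.0, 2.0])}
params_b = {"params": {"v": jnp.array(3.0)}}

def f(pa, pb):
    # Any pure JAX expression of both pytrees is differentiable w.r.t. both.
    return jnp.sum(pa["w"] ** 2) * pb["params"]["v"]

g_a, g_b = jax.grad(f, argnums=(0, 1))(params_a, params_b)
# g_a has params_a's structure: d f / d w = 2 * w * v, i.e. [6, 12]
# g_b has params_b's structure: d f / d v = sum(w ** 2), i.e. 5
print(g_a, g_b)
```

Each gradient mirrors the structure of its own argument, which is what lets the two halves be updated independently.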
The simulation pattern
Rather than using bridge.ToNNX or bridge.ToLinen (which work but
add wrapper-state ceremony), we’ll simulate interop: pull the
nnx Linear’s params out as a plain dict, and treat them as a peer
parameter pytree alongside the Linen params dict.
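```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx

# in_f, hidden, out_f, seed and rngs (an nnx.Rngs) are assumed to be defined.
nnx_linear = nnx.Linear(in_f, hidden, rngs=rngs)
nnx_params = {
    "kernel": nnx_linear.kernel.value,  # shape (in_f, hidden)
    "bias": nnx_linear.bias.value,      # shape (hidden,)
}

linen_dense = nn.Dense(features=out_f)
# Init on a dummy hidden vector; returns {"params": {"kernel": ..., "bias": ...}}
linen_params = linen_dense.init(jax.random.PRNGKey(seed + 1),
                                jnp.zeros((hidden,)))
```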
The forward then composes them directly:
```python
def forward(nnx_params, linen_params, x):
    h = x @ nnx_params["kernel"] + nnx_params["bias"]  # nnx.Linear, unrolled
    h = jax.nn.relu(h)
    out = linen_dense.apply(linen_params, h)           # Linen head via closure
    return out
```
The Linen module is captured as a closure; only its params are
treated as differentiable. The nnx side is "manually unrolled":
we write x @ kernel + bias instead of calling nnx_linear(x) because
we're differentiating with respect to the dict of arrays, not the
nnx Module.
The training step
```python
# x, y and lr are assumed to be in scope.
def loss_fn(nnx_params, linen_params):
    return jnp.mean((forward(nnx_params, linen_params, x) - y) ** 2)

loss_before = loss_fn(nnx_params, linen_params)
g_nnx, g_linen = jax.grad(loss_fn, argnums=(0, 1))(nnx_params, linen_params)
# Plain SGD, applied leaf by leaf to each pytree.
new_nnx = jax.tree_util.tree_map(lambda p, g: p - lr * g, nnx_params, g_nnx)
new_linen = jax.tree_util.tree_map(lambda p, g: p - lr * g, linen_params, g_linen)
loss_after = loss_fn(new_nnx, new_linen)
```
jax.grad(..., argnums=(0, 1)) returns a tuple of gradient
pytrees, one per differentiated argument and each with the same
structure as that argument. The SGD update is p - lr * g per leaf,
applied via tree_map to both pytrees.
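A hedged variant of the same step on toy arrays, using jax.value_and_grad so loss_before and both gradients come out of one call. The parameter values, x, y, lr and the sgd helper are hypothetical; the Linen side is mimicked by a nested {"params": {...}} dict and a hand-written affine head rather than an actual nn.Dense.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy params: a flat dict (nnx-like) and a nested dict (Linen-like).
nnx_params = {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))}
linen_params = {"params": {"kernel": jnp.ones((3, 1)), "bias": jnp.zeros((1,))}}
x = jnp.array([1.0, -0.5])
y = jnp.array([0.25])
lr = 0.05

def loss_fn(p_nnx, p_linen):
    h = jax.nn.relu(x @ p_nnx["kernel"] + p_nnx["bias"])
    out = h @ p_linen["params"]["kernel"] + p_linen["params"]["bias"]
    return jnp.mean((out - y) ** 2)

# value_and_grad returns (loss, (grad_nnx, grad_linen)) in a single pass.
loss_before, (g_nnx, g_linen) = jax.value_and_grad(loss_fn, argnums=(0, 1))(
    nnx_params, linen_params)

def sgd(params, grads):
    # tree_map pairs leaves by position, so nesting depth does not matter
    # as long as params and grads share the same structure.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

new_nnx, new_linen = sgd(nnx_params, g_nnx), sgd(linen_params, g_linen)
loss_after = loss_fn(new_nnx, new_linen)
print(float(loss_before), float(loss_after))
```

Because tree_map walks params and grads in lockstep, the same helper updates the flat nnx dict and the nested Linen dict.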
Why this teaches the right thing
The conceptual point of “training a mixed Linen+nnx model” is not “use the bridge wrapper”. It’s:
- Both frameworks expose params as plain jax.Array pytrees.
- jax.grad differentiates through plain JAX code.
- Therefore, you can compose params from both frameworks in one function and train them together.
The bridge wrappers are sugar; the underlying machinery is pure JAX composition. By simulating the interop, the user sees this machinery directly.
Common pitfalls
- Forgetting argnums=(0, 1). The default is argnums=0, so you'd get only the first gradient.
- Capturing linen_dense as a non-pure value inside the loss. It's fine here because linen_dense is an unbound Linen Module instance that holds only configuration; only linen_params carries trainable state.
- Updating one half but not the other. Both gradient pytrees must be applied; otherwise only part of the model trains and the loss won't decrease as expected.
- Using the nnx Module directly inside loss_fn. The nnx Module's params are wrapped in nnx.Param, so plain jax.grad over the Module is awkward. Pull the arrays out as a plain dict first.
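The first pitfall in miniature, on a hypothetical two-argument loss: with the default argnums, jax.grad returns a single gradient array for the first argument only; with argnums=(0, 1) it returns a tuple with one gradient per argument.

```python
import jax
import jax.numpy as jnp

def loss(a, b):
    return jnp.sum(a * b)

a, b = jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])

g_default = jax.grad(loss)(a, b)               # argnums=0: gradient w.r.t. a only (equals b)
g_both = jax.grad(loss, argnums=(0, 1))(a, b)  # tuple: (grad w.r.t. a, grad w.r.t. b)

print(g_default)
print(len(g_both))
```

Unpacking g_default as "g_nnx, g_linen" would silently split one array along its first axis instead of raising an error whenever that axis has length 2, which is why this mistake can be easy to miss.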
Problem
Write mixed_train_step(seed, x, y, lr) with hidden = 4:
- Build nnx_linear = nnx.Linear(in_features=x.shape[-1], out_features=hidden, rngs=nnx.Rngs(int(seed))). Extract nnx_params = {"kernel": ..., "bias": ...}.
- Build linen_dense = nn.Dense(features=int(y.shape[-1])). Init it via jax.random.PRNGKey(int(seed) + 1) on a dummy hidden vector to get linen_params.
- Define forward(nnx_params, linen_params, x): an nnx_linear-style affine map, ReLU, then linen_dense.apply(linen_params, h).
- Define loss_fn(nnx_params, linen_params) = jnp.mean((forward(...) - y) ** 2).
- Compute loss_before. Use jax.grad(loss_fn, argnums=(0, 1)) to get (g_nnx, g_linen). SGD-update both with lr. Compute loss_after on the new params.
- Return jnp.array([float(loss_before), float(loss_after)]).
Loss must decrease.
Inputs:
- seed: int (passed as float).
- x, y: 1-D JAX arrays.
- lr: float (passed as float).
Output: length-2 array [loss_before, loss_after].