
NNX Bridge: Train Mixed Linen+NNX Model

Why this matters

During an incremental Linen → nnx migration you’ll have hybrid models: an nnx encoder feeding into a Linen head, or a Linen embed layer feeding into an nnx Transformer. The training step must update both halves in a single backward pass.

The pedagogical trick: we don’t need bridge wrappers for this. The fundamental fact is that gradients flow through plain JAX. As long as you can express the forward pass as a pure function of two parameter pytrees (one nnx-derived, one Linen-derived), jax.grad can differentiate with respect to both, and you apply SGD to each.
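Here is a toy sketch of that fact in plain JAX. The dicts params_a and params_b are hypothetical stand-ins for the nnx-derived and Linen-derived halves; neither framework is involved. A single jax.grad call with argnums=(0, 1) returns one gradient pytree per argument.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy parameter pytrees standing in for the two halves.
params_a = {"w": jnp.array([1.0, 2.0])}   # plays the "nnx" role
params_b = {"scale": jnp.array(3.0)}      # plays the "Linen" role
x = jnp.array([0.5, -1.0])

def loss(pa, pb):
    # A pure function of both pytrees: (scale * <w, x>) ** 2
    return (pb["scale"] * jnp.dot(pa["w"], x)) ** 2

# One backward pass, one gradient pytree per differentiated argument.
ga, gb = jax.grad(loss, argnums=(0, 1))(params_a, params_b)
print(ga["w"], gb["scale"])
```

Here <w, x> = -1.5, so d/dw = 2 * (-4.5) * 3 * x = [-13.5, 27.0] and d/dscale = 2 * (-4.5) * (-1.5) = 13.5.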

The simulation pattern

Rather than using bridge.ToNNX or bridge.ToLinen (which work but add wrapper-state ceremony), we’ll simulate interop: pull the nnx Linear’s params out as a plain dict, and treat them as a peer parameter pytree alongside the Linen params dict.

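import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx

rngs = nnx.Rngs(seed)  # in_f, hidden, out_f, seed are ints supplied by the caller

# nnx side: build the Module, then extract its params as a plain dict of arrays
nnx_linear = nnx.Linear(in_f, hidden, rngs=rngs)
nnx_params = {
    "kernel": nnx_linear.kernel.value,
    "bias":   nnx_linear.bias.value,
}

# Linen side: init on a dummy hidden-sized input to get {"params": {...}}
linen_dense  = nn.Dense(features=out_f)
linen_params = linen_dense.init(jax.random.PRNGKey(seed + 1),
                                jnp.zeros((hidden,)))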

The forward then composes them directly:

def forward(nnx_params, linen_params, x):
    h   = x @ nnx_params["kernel"] + nnx_params["bias"]  # nnx.Linear, unrolled by hand
    h   = jax.nn.relu(h)
    out = linen_dense.apply(linen_params, h)             # Linen Dense, driven by its params
    return out

The Linen module is captured in a closure; only its params are treated as differentiable. The nnx side is “manually unrolled”: we compute x @ kernel + bias (equivalent to jnp.dot(x, kernel) + bias) instead of calling nnx_linear(x), because we are differentiating with respect to the dict of arrays, not the nnx Module.

The training step

def loss_fn(nnx_params, linen_params):
    return jnp.mean((forward(nnx_params, linen_params, x) - y) ** 2)

loss_before = loss_fn(nnx_params, linen_params)
g_nnx, g_linen = jax.grad(loss_fn, argnums=(0, 1))(nnx_params, linen_params)

new_nnx   = jax.tree_util.tree_map(lambda p, g: p - lr*g, nnx_params,   g_nnx)
new_linen = jax.tree_util.tree_map(lambda p, g: p - lr*g, linen_params, g_linen)

loss_after = loss_fn(new_nnx, new_linen)

jax.grad(..., argnums=(0, 1)) returns a tuple of gradient pytrees, one per differentiated argument, each with the same structure as that argument. The SGD update is p - lr * g per leaf, applied to both pytrees via tree_map.
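The sketch below makes the tuple and the structure-mirroring explicit. It uses small hypothetical stand-in dicts (nested like Linen’s {"params": ...} layout) and a hand-written Dense in place of linen_dense.apply so that it stays self-contained. The sgd helper is a hypothetical convenience, not a library function.

```python
import jax
import jax.numpy as jnp

def sgd(params, grads, lr):
    # One SGD step, leaf by leaf: p <- p - lr * g
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Hypothetical stand-ins for the two halves.
nnx_params = {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}
linen_params = {"params": {"kernel": jnp.ones((2, 1)), "bias": jnp.zeros((1,))}}
x, y = jnp.array([1.0, 2.0]), jnp.array([0.0])

def loss_fn(a, b):
    h = jax.nn.relu(x @ a["kernel"] + a["bias"])
    out = h @ b["params"]["kernel"] + b["params"]["bias"]  # hand-written Dense
    return jnp.mean((out - y) ** 2)

grads = jax.grad(loss_fn, argnums=(0, 1))(nnx_params, linen_params)
g_nnx, g_linen = grads  # a 2-tuple: one gradient pytree per argnum

# Each gradient pytree mirrors the structure of its argument.
same_nnx = jax.tree_util.tree_structure(g_nnx) == jax.tree_util.tree_structure(nnx_params)
same_linen = jax.tree_util.tree_structure(g_linen) == jax.tree_util.tree_structure(linen_params)

loss_before = loss_fn(nnx_params, linen_params)
loss_after = loss_fn(sgd(nnx_params, g_nnx, 0.01), sgd(linen_params, g_linen, 0.01))
print(loss_before, loss_after)
```

With these values the initial output is 6.0, so loss_before is 36.0, and one step at lr = 0.01 brings it below 8.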

Why this teaches the right thing

The conceptual point of “training a mixed Linen+nnx model” is not “use the bridge wrapper”. It’s:

  1. Both frameworks expose params as plain jax.Array pytrees.
  2. jax.grad differentiates through plain JAX code.
  3. Therefore, you can compose params from both frameworks in one function and train them together.

The bridge wrappers are sugar; the underlying machinery is pure JAX composition. By simulating the interop, the user sees this machinery directly.

Common pitfalls

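  • Forgetting argnums=(0, 1). The default is argnums=0, so you’d get only the gradient for the first argument (the nnx params), and the Linen half would never be updated.
  • Closing over parameters instead of passing them. Capturing linen_dense in a closure is fine here, because a Linen Module instance is a stateless description of the layer; all trainable state lives in linen_params. If linen_params itself were captured by the closure rather than passed as an argument, jax.grad would treat it as a constant and return no gradient for it.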
  • Updating one half but not the other. Both gradient pytrees must be applied; otherwise the loss won’t decrease as expected.
  • Using the nnx Module directly inside loss_fn. The nnx Module’s params are wrapped in nnx.Param; jax.grad over the Module is awkward. Pull out arrays as a plain dict first.
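The argnums pitfall, and the related case of closing over a parameter pytree instead of passing it as an argument, can both be seen in a toy example. The dicts a and b and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

a = {"w": jnp.array(2.0)}
b = {"w": jnp.array(3.0)}

def loss(pa, pb):
    return pa["w"] * pb["w"]

# Default argnums=0: a single gradient pytree, shaped like `a` only.
g_default = jax.grad(loss)(a, b)
# argnums=(0, 1): a tuple (grad_a, grad_b).
g_both = jax.grad(loss, argnums=(0, 1))(a, b)

def loss_closed(pa):
    return pa["w"] * b["w"]  # `b` is closed over, so it is a constant to jax.grad

g_closed = jax.grad(loss_closed)(a)  # gradient for `a` only; nothing for `b`
print(g_default, g_both, g_closed)
```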

Problem

Write mixed_train_step(seed, x, y, lr) with hidden = 4:

  1. Build nnx_linear = nnx.Linear(in_features=x.shape[-1], out_features=hidden, rngs=nnx.Rngs(int(seed))). Extract nnx_params = {"kernel": ..., "bias": ...}.
  2. Build linen_dense = nn.Dense(features=int(y.shape[-1])). Init via jax.random.PRNGKey(int(seed) + 1) on a dummy hidden vector to get linen_params.
  3. Define forward(nnx_params, linen_params, x): nnx_linear-style affine, ReLU, then linen_dense.apply(linen_params, h).
  4. Define loss_fn(nnx_params, linen_params) = jnp.mean((forward(...) - y) ** 2).
  5. Compute loss_before. Use jax.grad(loss_fn, argnums=(0, 1)) to get (g_nnx, g_linen). SGD update both with lr. Compute loss_after on the new params.
  6. Return jnp.array([float(loss_before), float(loss_after)]).

Loss must decrease.

Inputs:

  • seed: int (passed as float).
  • x, y: 1-D JAX arrays.
  • lr: float (passed as float).

Output: length-2 array [loss_before, loss_after].

