
NNX Mixed-Precision Step

Why this matters

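On modern accelerators (A100/H100/TPU), bfloat16 matmuls typically run several times faster than float32 matmuls (the exact ratio depends on the chip), and bf16 values take half the memory: 2 bytes per element instead of 4. Most large-scale training runs use some form of mixed precision.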
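The memory half of that claim is easy to check directly. The sketch below uses an arbitrarily sized matrix (the 1024 x 1024 shape is an assumption for illustration). The speed-up depends on the accelerator and cannot be shown by a small CPU snippet.

```python
import jax.numpy as jnp

# An arbitrary 1024 x 1024 weight matrix, stored in fp32 and in bf16.
w_fp32 = jnp.zeros((1024, 1024), dtype=jnp.float32)
w_bf16 = w_fp32.astype(jnp.bfloat16)

# Bytes per element: fp32 uses 4, bf16 uses 2.
print(w_fp32.dtype.itemsize, w_bf16.dtype.itemsize)  # 4 2
# Total size in MiB: bf16 halves it.
print(w_fp32.nbytes // 2**20, w_bf16.nbytes // 2**20)  # 4 2
```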

But you can't just cast everything to bf16. Small gradients can underflow, and small parameter updates get rounded away when they are added to bf16 weights, so training stalls or degrades. The standard recipe:

  • Master weights in fp32 — accumulators and optimizer state stay in high precision.
  • Forward + backward in bf16 — the matmuls and activations, which dominate runtime, run fast.
  • Loss scaling: multiply the loss by a factor S before backprop so small grads stay representable, then divide the grads by S before the optimizer step. For bf16 specifically, loss scaling matters less than for fp16, but we still teach the pattern because the same code is needed for fp16 on older GPUs.
  • Optimizer step in fp32: an update such as lr * grad ≈ 1e-7 is tiny relative to a weight of order 1. fp32's 24-bit significand can still register it, but bf16's 8-bit significand cannot: in bf16, w + 1e-7 rounds straight back to w, so the update is lost.
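A quick check of the optimizer-step point: the update value itself is representable in bf16, but adding it to a weight of order 1 in bf16 loses it, while fp32 keeps it. The weight and update magnitudes below are illustrative assumptions.

```python
import jax.numpy as jnp

w = 1.0        # a typical weight magnitude (illustrative)
update = 1e-7  # lr * grad, e.g. lr = 1e-4 and grad = 1e-3 (illustrative)

# fp32: the update exceeds half the fp32 spacing at 1.0 (~6e-8), so it survives.
w_fp32 = jnp.asarray(w, jnp.float32) + jnp.asarray(update, jnp.float32)

# bf16: the spacing at 1.0 is 2**-7 (~0.0078), so w + update rounds back to w.
w_bf16 = jnp.asarray(w, jnp.bfloat16) + jnp.asarray(update, jnp.bfloat16)

print(bool(jnp.asarray(update, jnp.bfloat16) != 0))  # True: the update alone is representable
print(bool(w_fp32 > 1.0))   # True: fp32 registers the step
print(bool(w_bf16 == 1.0))  # True: bf16 drops the step
```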

The recipe in nnx

The fp32 master copy lives on the model. To do the forward in bf16 without permanently downcasting the params, we use nnx.split to peel off the params state, cast the leaves to bf16 inside loss_fn, then nnx.merge them with the static graphdef into a temporary bf16 model. Outside loss_fn, the original model object is still fp32.

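```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

# Illustrative setup (sizes, lr and loss_scale are assumptions):
# an fp32 master model plus an SGD optimizer over its nnx.Param state.
model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(0.1), wrt=nnx.Param)
x, y = jnp.ones((8, 4)), jnp.zeros((8, 2))
loss_scale = 1024.0

def loss_fn(model, x, y):
    # 1. Split off the Param state and cast it to bf16.
    #    (nnx.Linear holds only Params; a model with other Variables
    #    would need an extra filter, e.g. nnx.split(model, nnx.Param, ...).)
    gdef, state = nnx.split(model, nnx.Param)
    state_bf16 = jax.tree_util.tree_map(
        lambda p: p.astype(jnp.bfloat16), state
    )
    model_bf16 = nnx.merge(gdef, state_bf16)

    # 2. Forward + loss in bf16.
    x_bf16 = x.astype(jnp.bfloat16)
    y_bf16 = y.astype(jnp.bfloat16)
    pred = model_bf16(x_bf16)
    loss_bf16 = jnp.mean((pred - y_bf16) ** 2)

    # 3. Cast the loss to fp32 and scale it, so tiny grads survive the bf16 backward pass.
    return loss_bf16.astype(jnp.float32) * loss_scale

_, grads = nnx.value_and_grad(loss_fn)(model, x, y)

# 4. Cast grads to fp32 (defensive: they already are) and unscale.
grads = jax.tree_util.tree_map(
    lambda g: g.astype(jnp.float32) / loss_scale, grads
)

# 5. Optimizer step in fp32 on the fp32 master params.
optimizer.update(model, grads)
```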

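The flow is fp32 → bf16 → bf16 → fp32: fp32 master params, bf16 forward, bf16 backward, fp32 grads and update. The bf16 zone opens at the explicit cast inside loss_fn. It closes during backprop: differentiating through astype(jnp.bfloat16) casts the cotangent back to the input's dtype, so the grads for the fp32 master params arrive as fp32.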

Why nnx.split + nnx.merge?

A direct approach, such as model.kernel.value = model.kernel.value.astype(jnp.bfloat16), would overwrite the master copy. We DON'T want that. Splitting instead gives us a state pytree we can cast freely (tree_map builds a new pytree and leaves the original alone); merging wraps the bf16 leaves in a new, temporary model. The original model object is untouched.

Why scale the loss?

fp16 has limited dynamic range: its smallest positive (subnormal) value is about 6e-8. A grad around 1e-5 is representable in fp16, but one around 1e-9 rounds to zero ("underflow"). Multiplying the loss by S multiplies every grad by S (chain rule), lifting small grads into range. Dividing by S after the grads are computed recovers their original values (up to rounding), and the grads have survived the low-precision backward pass.

bf16 has the same 8-bit exponent as fp32, and therefore essentially the same dynamic range, so loss scaling matters much less than for fp16. bf16 training commonly runs with loss_scale = 1.0. Test 3 below uses loss_scale = 1.0 to demonstrate the no-scale case.
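Both points can be checked on a single value. The sketch assumes a gradient of 1e-9 and a scale of 2**16 (both illustrative). It stands in for the backward pass by casting the scaled value to fp16, which is where an unscaled grad would underflow.

```python
import jax.numpy as jnp

g_true = 1e-9   # a tiny gradient value (illustrative)
S = 2.0 ** 16   # a power-of-2 loss scale (illustrative)

# fp16: 1e-9 is below half the smallest subnormal (~6e-8), so it underflows to zero.
g_fp16 = jnp.asarray(g_true, jnp.float16)

# With loss scaling the backward pass carries S * g (~6.6e-5), which fp16 can hold;
# unscaling in fp32 afterwards recovers the value to fp16 precision.
g_scaled_fp16 = jnp.asarray(g_true * S, jnp.float16)
g_recovered = g_scaled_fp16.astype(jnp.float32) / S

# bf16 shares fp32's exponent range, so 1e-9 survives without any scaling
# (with only about 3 significant decimal digits).
g_bf16 = jnp.asarray(g_true, jnp.bfloat16)

print(float(g_fp16))       # 0.0
print(float(g_recovered))  # approximately 1e-9
print(float(g_bf16))       # approximately 1e-9
```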

Common pitfalls

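  • Casting the model permanently to bf16. Then the master copy is bf16 too, which defeats the entire point. Cast only inside loss_fn, via split/merge.
  • Forgetting to unscale grads. The optimizer would take steps loss_scale times too large.
  • Not casting x and y to bf16. JAX's type promotion turns an operation on bf16 and fp32 operands into fp32, so with an fp32 x the matmul still runs in fp32 and there is no speed-up. Leaving y in fp32 likewise pushes the loss computation back to fp32.
  • Casting grads to bf16. Don't: the grads should stay fp32 for the fp32 optimizer step. They already come back as fp32 because differentiating through the astype(jnp.bfloat16) cast on each fp32 param converts its gradient back to fp32; the fp32 loss is not what determines this. The defensive astype(jnp.float32) is therefore a no-op, but it documents the intent.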
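The third and fourth pitfalls in miniature, on plain jax.numpy arrays rather than a full model (shapes are arbitrary). The loss here is deliberately left in bf16 to show that the gradient's dtype follows the parameter, not the loss.

```python
import jax
import jax.numpy as jnp

a_bf16 = jnp.ones((4, 4), jnp.bfloat16)
b_fp32 = jnp.ones((4, 4), jnp.float32)

# Mixing bf16 and fp32 operands promotes the matmul to fp32.
print((a_bf16 @ b_fp32).dtype)                       # float32
print((a_bf16 @ b_fp32.astype(jnp.bfloat16)).dtype)  # bfloat16

# The gradient w.r.t. an fp32 parameter is fp32 even though this loss is
# computed and returned in bf16: the transpose of astype(bfloat16) casts
# the cotangent back to the parameter's dtype.
def loss_bf16(w):
    return jnp.sum(w.astype(jnp.bfloat16) ** 2)

g = jax.grad(loss_bf16)(jnp.ones(3, jnp.float32))
print(g.dtype)  # float32
```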

Problem

Implement mixed_precision_step(seed, x, y, lr, loss_scale):

  1. Build model = nnx.Linear(x.shape[-1], y.shape[-1], rngs=...) and optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param). The model’s params are fp32 — that’s the master copy.
  2. Inside loss_fn(model, x, y):
    • Split off the nnx.Param state with nnx.split(model, nnx.Param).
    • Cast the state’s leaves to bf16 via tree_map.
    • Re-merge into model_bf16 = nnx.merge(gdef, state_bf16).
    • Forward + MSE in bf16, then loss_bf16.astype(jnp.float32) * loss_scale.
  3. Compute grads, then tree_map(lambda g: g.astype(jnp.float32) / loss_scale, grads).
  4. optimizer.update(model, grads).
  5. Recompute the unscaled MSE at the new params, in bf16 as inside loss_fn, then cast it to fp32. Return it as a 1-D array of shape (1,).

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D_in).
  • y: 2-D (N, D_out).
  • lr: float.
  • loss_scale: float — typically a power of 2.

Output: a 1-D array of shape (1,) containing [final_loss_after_step_in_fp32].
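The inputs above describe loss_scale as typically a power of 2. One common reason, sketched here on an arbitrary random fp32 vector: scaling by a power of 2 only changes a float's exponent, so the scale/unscale round trip is exact unless it overflows or underflows. A non-power-of-2 scale such as 1000.0 rounds at each step and can be off by an ulp.

```python
import jax
import jax.numpy as jnp

# Arbitrary test values (illustrative): 1000 standard-normal fp32 samples.
x = jax.random.normal(jax.random.PRNGKey(0), (1000,), jnp.float32)

S = 1024.0  # 2**10
# Multiplying and then dividing by a power of 2 only shifts the exponent,
# so every value round-trips bit-for-bit.
roundtrip = (x * S) / S
print(bool(jnp.all(roundtrip == x)))  # True
```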

Hints

flax nnx mixed-precision bfloat16 training
