NNX Mixed-Precision Step
Why this matters
On modern accelerators (A100/H100/TPU), bfloat16 matmuls are 2-8x
faster than float32 and use half the memory. Every serious training
run uses some form of mixed precision.
But you can’t just cast everything to bf16: bf16 keeps only 8 bits of significand, so small parameter updates get rounded away and training can stall or diverge. The standard recipe:
- Master weights in fp32: accumulators and optimizer state stay in high precision.
- Forward + backward in bf16: the matmuls and activations, which dominate runtime, run fast.
- Loss scaling: multiply the loss before backprop so small grads stay representable; divide grads back before the optimizer. For bf16 specifically, loss scaling matters less than for fp16, but we still teach the pattern because the same code is needed for fp16 on older GPUs.
- Optimizer step in fp32: small updates (lr * grad ≈ 1e-7) need fp32’s mantissa precision; added to a bf16 weight of typical magnitude, they’d be rounded away.
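The last point can be checked numerically. The snippet below is a minimal sketch, assuming a weight of magnitude 1 and an update of about 1e-7 (both values and all variable names are illustrative, not from the original problem).

```python
import jax.numpy as jnp

# Hypothetical values: a weight near 1 and an update lr * grad of about 1e-7.
weight = jnp.float32(1.0)
update = jnp.float32(1e-7)

# fp32 step: the spacing between fp32 values just above 1.0 is about 1.2e-7,
# so the update moves the weight to the next representable value.
new_fp32 = weight + update

# bf16 step: the spacing just above 1.0 is 2**-7 (about 0.0078),
# so the same update is rounded away entirely.
new_bf16 = weight.astype(jnp.bfloat16) + update.astype(jnp.bfloat16)

print(float(new_fp32) > 1.0)   # the fp32 weight changed
print(float(new_bf16) == 1.0)  # the bf16 weight did not move
```

This is why the optimizer always steps the fp32 master copy, even when the forward and backward passes run in bf16.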
The recipe in nnx
The fp32 master copy lives on the model. To do the forward in bf16
without permanently downcasting the params, we use
nnx.split to peel off the params state, cast the leaves to bf16
inside loss_fn, then nnx.merge them with the static graphdef
into a temporary bf16 model. Outside loss_fn, the original model
object is still fp32.
import jax
import jax.numpy as jnp
from flax import nnx

# model, optimizer, x, y, and loss_scale come from the surrounding step function.

def loss_fn(model, x, y):
    # 1. Split off the Param state and cast it to bf16.
    gdef, state = nnx.split(model, nnx.Param)
    state_bf16 = jax.tree_util.tree_map(
        lambda p: p.astype(jnp.bfloat16), state
    )
    model_bf16 = nnx.merge(gdef, state_bf16)
    # 2. Forward + loss in bf16.
    x_bf16 = x.astype(jnp.bfloat16)
    y_bf16 = y.astype(jnp.bfloat16)
    pred = model_bf16(x_bf16)
    loss_bf16 = jnp.mean((pred - y_bf16) ** 2)
    # 3. Scale the loss in fp32 so tiny grads survive the low-precision backprop.
    return loss_bf16.astype(jnp.float32) * loss_scale

_, grads = nnx.value_and_grad(loss_fn)(model, x, y)
# 4. Cast grads to fp32 (defensive) and unscale.
grads = jax.tree_util.tree_map(
    lambda g: g.astype(jnp.float32) / loss_scale, grads
)
# 5. Optimizer step in fp32 with the fp32 master params.
optimizer.update(model, grads)
The flow is fp32 → bf16 → bf16 → fp32. The bf16 zone is bracketed
by the cast inside loss_fn and the cast back when handling grads.
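The dtype flow can be observed without nnx at all. The following is a rough pure-JAX sketch in which a plain fp32 weight matrix stands in for the master params; the toy shapes, values, and names are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Enter the bf16 zone: cast the fp32 weights and inputs inside the loss.
    w_bf16 = w.astype(jnp.bfloat16)
    pred = x.astype(jnp.bfloat16) @ w_bf16
    loss_bf16 = jnp.mean((pred - y.astype(jnp.bfloat16)) ** 2)
    # Leave the bf16 zone: the returned loss is fp32.
    return loss_bf16.astype(jnp.float32)

# Hypothetical toy data; w plays the role of the fp32 master copy.
w = jnp.full((3, 2), 0.5, dtype=jnp.float32)
x = jnp.ones((4, 3), dtype=jnp.float32)
y = jnp.zeros((4, 2), dtype=jnp.float32)

loss, grad_w = jax.value_and_grad(loss_fn)(w, x, y)
# Both are fp32: grads take the dtype of the input they are taken with respect to.
print(loss.dtype, grad_w.dtype)
```

Because the cast to bf16 is part of the differentiated function, the backward pass casts the cotangents back to the dtype of `w` on the way out.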
Why nnx.split + nnx.merge?
A direct approach, such as model.kernel = model.kernel.astype(jnp.bfloat16),
would overwrite the master copy, which is exactly what we want to avoid.
Splitting gives us a state pytree we can cast freely; merging produces a new
temporary model that wraps the bf16 leaves. The original model
object is untouched.
Why scale the loss?
fp16 has limited dynamic range: its smallest positive value is about
6e-8. Grads around 1e-5 are still representable, but grads around
1e-9 round to zero (“underflow”). Multiplying the loss by S multiplies
every grad by S (chain rule), which lifts them into the representable
range. Dividing by S in fp32 after the grads are computed restores the
original magnitude, and the small grads have survived the
low-precision round-trip.
With bf16 specifically, dynamic range is similar to fp32, so loss
scaling matters LESS than for fp16. Real frameworks often use
loss_scale = 1.0 for bf16. Test 3 below uses loss_scale = 1.0
to demonstrate the no-scale case.
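The dynamic-range argument can be checked with a few casts. This is a minimal sketch; the gradient value 1e-9 and the scale 2**16 are illustrative choices, not values from the problem.

```python
import jax.numpy as jnp

# Hypothetical tiny gradient value and a power-of-two loss scale.
grad = jnp.float32(1e-9)
scale = jnp.float32(2.0 ** 16)

# Without scaling: the value underflows to zero in fp16 but survives in bf16,
# because bf16 shares fp32's 8-bit exponent.
grad_fp16 = grad.astype(jnp.float16)
grad_bf16 = grad.astype(jnp.bfloat16)
print(float(grad_fp16), float(grad_bf16))

# With scaling: the scaled value is representable in fp16, and dividing by
# the scale in fp32 recovers approximately the original gradient.
scaled_fp16 = (grad * scale).astype(jnp.float16)
recovered = scaled_fp16.astype(jnp.float32) / scale
print(float(recovered))
```

A power-of-two scale only shifts the exponent, so scaling and unscaling add no rounding error of their own.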
Common pitfalls
- Casting the model permanently to bf16. Then the master copy is bf16 too, which defeats the entire point. Always use split/merge, or cast inside loss_fn only.
- Forgetting to unscale grads. The optimizer would take a step S× too large.
- Not casting x and y to bf16. JAX type promotion will promote the mixed bf16/fp32 matmul to fp32, so it still runs in fp32 and there is no speed-up.
- Casting grads to bf16. They come back as fp32 because they are taken with respect to the fp32 master params (the bf16 cast happens inside loss_fn). The defensive astype(jnp.float32) is a no-op but useful for documentation.
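Two of these pitfalls are easy to reproduce in isolation. The sketch below uses made-up toy arrays and a hypothetical quadratic loss; it shows the dtype of a mixed matmul and the effect of skipping the unscale step.

```python
import jax
import jax.numpy as jnp

x_fp32 = jnp.ones((4, 3), dtype=jnp.float32)
w_bf16 = jnp.ones((3, 2), dtype=jnp.bfloat16)

# Uncast inputs: mixing fp32 inputs with bf16 weights promotes the result to fp32.
mixed = x_fp32 @ w_bf16
pure = x_fp32.astype(jnp.bfloat16) @ w_bf16
print(mixed.dtype, pure.dtype)  # float32 bfloat16

# Forgotten unscale: grads of a scaled loss are scaled by the same factor.
def loss_fn(w, scale):
    return jnp.sum(w ** 2) * scale

w = jnp.array([1.0, 2.0], dtype=jnp.float32)
true_grad = jax.grad(loss_fn)(w, 1.0)
scaled_grad = jax.grad(loss_fn)(w, 128.0)  # never divided by the scale
print(scaled_grad / true_grad)             # every entry equals the loss scale
```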
Problem
Implement mixed_precision_step(seed, x, y, lr, loss_scale):
- Build model = nnx.Linear(x.shape[-1], y.shape[-1], rngs=...) and optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param). The model’s params are fp32; that’s the master copy.
- Inside loss_fn(model, x, y):
  - Split off the nnx.Param state with nnx.split(model, nnx.Param).
  - Cast the state’s leaves to bf16 via tree_map.
  - Re-merge into model_bf16 = nnx.merge(gdef, state_bf16).
  - Forward + MSE in bf16, then loss_bf16.astype(jnp.float32) * loss_scale.
- Compute grads, then tree_map(lambda g: g.astype(jnp.float32) / loss_scale, grads).
- optimizer.update(model, grads).
- Recompute the un-scaled MSE in bf16 (cast back to fp32) at the new params. Return it as a 1-D array of shape (1,).
Inputs:
- seed: float (cast to int).
- x: 2-D (N, D_in).
- y: 2-D (N, D_out).
- lr: float.
- loss_scale: float, typically a power of 2.
Output: 1-D (1,) — [final_loss_after_step_in_fp32].