Mixed-Precision Training Step
Why this matters
On modern accelerators (A100/H100/TPU), bfloat16 and float16 matmuls
are 2× to 8× faster than float32 and use half the memory. Every
serious training run uses some form of mixed precision.
But you can’t just cast everything to bf16 — gradients become small, parameter updates underflow, and training diverges. The standard recipe:
- Weights stored in fp32 (the “master copy”) — accumulators and optimizer state stay in high precision.
- Forward + backward in bf16 — the matmuls and activations, which dominate runtime, run fast.
- Loss scaling for fp16 (less needed for bf16; we still use it for teaching purposes) — multiply the loss before backprop so small grads stay representable; divide grads back before the optimizer.
- Optimizer step in fp32: small updates of lr * grad ≈ 1e-7 need fp32’s mantissa precision; added to a weight stored in bf16, they’d be rounded away entirely.
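The last point can be checked directly. The snippet below is a minimal sketch, assuming a weight of 1.0 and an update of 1e-7 (both values are illustrative, not taken from a real run): the update itself is representable in bf16, but it disappears once it is added to the weight.

```python
import jax.numpy as jnp

# Illustrative values (assumptions): a weight near 1.0 and a tiny SGD update.
w32 = jnp.asarray(1.0, dtype=jnp.float32)
u32 = jnp.asarray(1e-7, dtype=jnp.float32)
w16 = w32.astype(jnp.bfloat16)
u16 = u32.astype(jnp.bfloat16)

new32 = w32 + u32  # fp32 has 24 significand bits: the update is kept
new16 = w16 + u16  # bf16 has 8 significand bits: the update is rounded away

update_is_nonzero_in_bf16 = bool(u16 != 0)
fp32_weight_changed = bool(new32 != w32)
bf16_weight_changed = bool(new16 != w16)

print(update_is_nonzero_in_bf16, fp32_weight_changed, bf16_weight_changed)
```

This is why the master copy and the optimizer step stay in fp32 even though the matmuls run in bf16.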
The recipe
import jax
import jax.numpy as jnp

# `state`, `x`, `y` and `loss_scale` come from the surrounding training step.
def loss_fn(params):
    # 1. Cast inputs into low precision INSIDE the loss_fn
    params_bf16 = jax.tree_util.tree_map(
        lambda p: p.astype(jnp.bfloat16), params
    )
    x_bf16 = x.astype(jnp.bfloat16)
    y_bf16 = y.astype(jnp.bfloat16)
    # 2. Forward + loss in bf16
    preds = state.apply_fn({"params": params_bf16}, x_bf16).reshape(-1)
    loss_bf16 = jnp.mean((preds - y_bf16) ** 2)
    # 3. Scale the loss in fp32 to protect tiny grads
    loss_scaled = loss_bf16.astype(jnp.float32) * loss_scale
    return loss_scaled

# Grads come back in (possibly) bf16 dtype: cast and unscale
grads = jax.grad(loss_fn)(state.params)
grads = jax.tree_util.tree_map(
    lambda g: g.astype(jnp.float32) / loss_scale, grads
)
# Optimizer update happens in fp32 with fp32 params and fp32 grads
state = state.apply_gradients(grads=grads)
The flow is: fp32 → bf16 → bf16 → fp32. The bf16 zone is between
the cast inside loss_fn and the cast back when handling grads.
Why scale the loss?
fp16 has limited dynamic range: its smallest normal value is about 6e-5
and its smallest subnormal is about 6e-8. Grads around 1e-5 are still
representable in fp16 (as subnormals, with reduced precision), but grads
around 1e-9 round to zero: “underflow.” Multiplying the loss by S
multiplies every grad by S (chain rule). Then we divide by S after grads
are computed. The grads we see are the same up to rounding, but they
survived the low-precision round-trip.
With bf16 specifically, the dynamic range is essentially the same as
fp32’s (both use 8 exponent bits), so loss scaling matters LESS than for
fp16. Real frameworks often use loss_scale = 1.0 for bf16. We still teach
the pattern because the same code is needed for fp16 (e.g. on older GPUs
without bf16 support).
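The underflow argument can be reproduced on a single number. This is a small sketch, assuming a gradient value of 1e-9 and a scale of 2**16; the scale is larger than the 64 to 256 range used in the problem and is chosen only so the scaled value lands in fp16’s normal range.

```python
import jax.numpy as jnp

grad = jnp.asarray(1e-9, dtype=jnp.float32)  # illustrative tiny gradient
S = 2.0 ** 16                                # illustrative loss scale (assumption)

# Without scaling: the value is below fp16's smallest subnormal and flushes to 0.
naive_fp16 = grad.astype(jnp.float16)

# With scaling: scale in fp32, round-trip through fp16, unscale in fp32.
scaled_fp16 = (grad * S).astype(jnp.float16)
recovered = scaled_fp16.astype(jnp.float32) / S

# bf16 shares fp32's exponent range, so the same value survives unscaled.
naive_bf16 = grad.astype(jnp.bfloat16)

print(float(naive_fp16), float(recovered), float(naive_bf16))
```

Because S is a power of two, multiplying and dividing by it only shifts the exponent, so the scaling itself adds no rounding error.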
Why the cast inside the loss_fn, not outside?
Two reasons:
- The fp32 master copy of params lives outside loss_fn; only inside do we want bf16. After jax.grad finishes, we’re back in fp32 land.
- apply_gradients applies grads to params, expecting both to have the same dtype. If params are fp32 and grads were left as bf16, you’d get dtype mismatches or silent precision loss.
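The dtype consequences of the two placements can be inspected without any model library. The sketch below assumes a hypothetical one-layer linear model with a single weight vector "w"; the names and shapes are illustrative only.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy setup: fp32 master weights and one input row.
params = {"w": jnp.ones((3,), dtype=jnp.float32)}
x = jnp.array([[1.0, 2.0, 3.0]], dtype=jnp.float32)

def loss_cast_inside(p):
    # Cast happens inside: the caller's params stay fp32.
    w = p["w"].astype(jnp.bfloat16)
    preds = x.astype(jnp.bfloat16) @ w
    return jnp.mean(preds ** 2).astype(jnp.float32)

def loss_no_cast(p):
    # No cast here: the function uses whatever dtype it is given.
    preds = x.astype(jnp.bfloat16) @ p["w"]
    return jnp.mean(preds ** 2).astype(jnp.float32)

# Cast inside: grads are taken with respect to the fp32 master copy.
grads_inside = jax.grad(loss_cast_inside)(params)

# Cast outside: the "master copy" handed to jax.grad is already bf16.
params_bf16 = jax.tree_util.tree_map(lambda a: a.astype(jnp.bfloat16), params)
grads_outside = jax.grad(loss_no_cast)(params_bf16)

print(grads_inside["w"].dtype, grads_outside["w"].dtype)
```

jax.grad returns gradients in the dtype of the arguments it differentiates, so casting inside keeps the gradient tree in the master copy’s dtype, while casting outside leaves both params and grads in bf16. The explicit astype(jnp.float32) in the recipe is then a cheap safeguard rather than a conversion.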
Common pitfalls
- Casting params to bf16 BEFORE loss_fn: then the master copy is bf16 too, which defeats the entire point.
- Forgetting to unscale the grads: the optimizer would step S× too large.
- Not casting x and y to bf16: type promotion will upcast the mixed bf16/fp32 operands to fp32, so the matmul still runs in fp32 and there is no speed-up.
- Casting the loss back to fp32 BEFORE multiplying by loss_scale: this is fine and what the recipe does. Order matters less here than for the 1/loss_scale division of the grads (which MUST be in fp32).
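The promotion pitfall is easy to confirm by looking at result dtypes. A minimal sketch with made-up shapes:

```python
import jax.numpy as jnp

w_bf16 = jnp.ones((4, 2), dtype=jnp.bfloat16)  # weights already cast to bf16
x_fp32 = jnp.ones((3, 4), dtype=jnp.float32)   # inputs left in fp32 by mistake

# Mixed operands: JAX promotes bf16 with fp32 to fp32 before the matmul.
mixed = x_fp32 @ w_bf16

# Both operands in bf16: the matmul stays in bf16.
pure = x_fp32.astype(jnp.bfloat16) @ w_bf16

print(mixed.dtype, pure.dtype)
```

Checking the dtype of an intermediate activation like this is a quick way to verify that the forward pass is actually running in low precision.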
Problem
Implement a single mixed-precision training step:
- Build a TrainState with a tiny nn.Dense(1) and optax.sgd(lr).
- Inside loss_fn: cast params, x, y to bf16; do MSE forward in bf16; scale loss by loss_scale (cast to fp32 first).
- Compute grads with jax.grad.
- Cast grads to fp32 and divide by loss_scale.
- apply_gradients.
- Recompute the un-scaled MSE in bf16 (cast back to fp32) at the new params. Return as 1-D (1,).
Inputs:
- seed: float (cast to int).
- x: 2-D (N, D).
- y: 1-D (N,).
- lr: float.
- loss_scale: float, typically a power of 2 (e.g., 64, 128, 256).
Output: 1-D (1,) — [final_loss_after_step_in_fp32].