
Mixed-Precision Training Step

Why this matters

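On modern accelerators (A100/H100/TPU), bfloat16 and float16 matmuls typically run several times faster than float32 matmuls (the exact factor depends on the chip and, on NVIDIA GPUs, on whether float32 matmuls use TF32), and 16-bit values take half the memory of 32-bit ones. Nearly every large-scale training run uses some form of mixed precision.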

But you can’t just cast everything to bf16: small gradients can underflow (mainly a problem in fp16), small parameter updates get rounded away, and training can stall or diverge. The standard recipe:

  • Weights stored in fp32 (the “master copy”) — accumulators and optimizer state stay in high precision.
  • Forward + backward in bf16 — the matmuls and activations, which dominate runtime, run fast.
  • Loss scaling for fp16 (less needed for bf16; we still use it for teaching purposes) — multiply the loss before backprop so small grads stay representable; divide grads back before the optimizer.
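  • Optimizer step in fp32: adding a small update lr * grad ≈ 1e-7 to a weight near 1.0 needs fp32’s 24-bit significand. bf16 keeps only 8 significant bits (spacing 2^-7 ≈ 0.0078 just above 1.0), so the sum rounds back to the old weight and the update is silently lost, even though 1e-7 itself is representable in bf16.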
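The last bullet is easy to check numerically. A minimal sketch, assuming a weight of 1.0 and an update of 1e-7 (values chosen for illustration; `apply_tiny_update` is a hypothetical helper): the same addition is performed once with bf16 operands and once with fp32 operands.

```python
import jax.numpy as jnp


def apply_tiny_update(dtype, w=1.0, update=1e-7):
    """Hypothetical helper: add `update` to `w`, both stored in `dtype`."""
    w_arr = jnp.asarray(w, dtype=dtype)
    u_arr = jnp.asarray(update, dtype=dtype)
    return w_arr + u_arr


# The update itself is representable in bf16 (same exponent range as fp32)...
print(bool(jnp.asarray(1e-7, dtype=jnp.bfloat16) > 0))  # True
# ...but 1.0 + 1e-7 rounds back to 1.0 in bf16, so the step is lost.
print(bool(apply_tiny_update(jnp.bfloat16) == 1.0))  # True
# fp32 keeps the step: the result is the next float above 1.0.
print(bool(apply_tiny_update(jnp.float32) > 1.0))  # True
```

This is why the recipe keeps the master weights and the update arithmetic in fp32 even though the forward and backward passes run in bf16.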

The recipe

import jax
import jax.numpy as jnp

def mixed_precision_step(state, x, y, loss_scale):
    """One mixed-precision step; `state` is a flax TrainState with fp32 params."""

    def loss_fn(params):
        # 1. Cast params and inputs to bf16 INSIDE loss_fn; the fp32 master
        #    copy in state.params is never overwritten.
        params_bf16 = jax.tree_util.tree_map(
            lambda p: p.astype(jnp.bfloat16), params
        )
        x_bf16 = x.astype(jnp.bfloat16)
        y_bf16 = y.astype(jnp.bfloat16)

        # 2. Forward + loss in bf16
        preds = state.apply_fn({"params": params_bf16}, x_bf16).reshape(-1)
        loss_bf16 = jnp.mean((preds - y_bf16) ** 2)

        # 3. Cast the loss to fp32, then scale it to protect tiny grads
        loss_scaled = loss_bf16.astype(jnp.float32) * loss_scale
        return loss_scaled

    # jax.grad returns grads in the params' dtype (fp32), but their values
    # went through the bf16 backward pass. The cast is a cheap safeguard;
    # the division by loss_scale happens in fp32.
    grads = jax.grad(loss_fn)(state.params)
    grads = jax.tree_util.tree_map(
        lambda g: g.astype(jnp.float32) / loss_scale, grads
    )

    # Optimizer update happens in fp32 with fp32 params and fp32 grads
    return state.apply_gradients(grads=grads)

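The dtype flow is: fp32 params → bf16 forward → bf16 backward → fp32 grads and optimizer update. Everything between the casts at the top of loss_fn and the conversion of the grads back to fp32 runs in bf16; the only fp32 work inside loss_fn is the scalar loss scaling.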
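To watch the dtypes change without Flax, here is a minimal sketch that swaps state.apply_fn for a hand-written linear model. The parameter names `w` and `b`, the `seen` dict, and the toy data are illustrative assumptions, not part of the recipe; `seen` records the dtypes inside loss_fn while jax.grad traces it.

```python
import jax
import jax.numpy as jnp

seen = {}  # hypothetical: dtypes recorded at trace time inside loss_fn


def loss_fn(params, x, y, loss_scale):
    # bf16 zone starts here: cast params and data
    p16 = jax.tree_util.tree_map(lambda a: a.astype(jnp.bfloat16), params)
    preds = x.astype(jnp.bfloat16) @ p16["w"] + p16["b"]
    loss16 = jnp.mean((preds - y.astype(jnp.bfloat16)) ** 2)
    seen["preds"], seen["loss_bf16"] = preds.dtype, loss16.dtype
    # Leave the bf16 zone: scale the loss in fp32
    return loss16.astype(jnp.float32) * loss_scale


params = {"w": jnp.ones((3,), jnp.float32), "b": jnp.zeros((), jnp.float32)}
x = jnp.ones((4, 3), jnp.float32)
y = jnp.zeros((4,), jnp.float32)

# Differentiate w.r.t. the fp32 params (argument 0), then unscale in fp32
grads = jax.grad(loss_fn)(params, x, y, 128.0)
grads = jax.tree_util.tree_map(lambda g: g.astype(jnp.float32) / 128.0, grads)

print(params["w"].dtype, seen["preds"], seen["loss_bf16"], grads["w"].dtype)
# float32 bfloat16 bfloat16 float32
```

With all-ones inputs and weights, every prediction is 3.0, so each unscaled gradient entry comes out as 6.0; the power-of-two scale of 128 does not change that value.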

Why scale the loss?

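fp16 has limited dynamic range: its smallest positive normal value is about 6.1e-5 and its smallest subnormal is about 6e-8. A gradient around 1e-5 is still representable (as a subnormal, with reduced precision), but one around 1e-9 rounds to zero: it underflows. Multiplying the loss by S multiplies every gradient by S (chain rule), which lifts small gradients into the representable range. Dividing by S afterwards, in fp32, recovers the true gradients up to rounding, but now they have survived the low-precision backward pass.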

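bf16 keeps fp32’s 8-bit exponent, so its dynamic range is essentially the same as fp32’s (smallest normal value ≈ 1.2e-38), and loss scaling matters far LESS than for fp16. In practice, bf16 training usually runs without loss scaling (loss_scale = 1.0). We still teach the pattern because the same code is needed for fp16, e.g. on older GPUs such as the V100 that lack native bf16 support.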
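The underflow argument can be seen on a single value. This sketch only mimics storing one gradient in fp16 (a real backward pass rounds every intermediate, so it is a simplification); the gradient 1e-9 and the scale 2**16 are illustrative choices.

```python
import jax.numpy as jnp

grad_value = 1e-9        # a tiny "true" gradient (illustrative)
loss_scale = 2.0 ** 16   # power-of-two scale, kept in fp32

# Stored directly in fp16, the gradient underflows to zero.
g_fp16 = jnp.asarray(grad_value, dtype=jnp.float16)

# Scaled first, it lands in fp16's normal range (about 6.6e-5 > 6.1e-5)...
g_scaled_fp16 = jnp.asarray(grad_value * loss_scale, dtype=jnp.float16)
# ...and dividing by the scale in fp32 recovers it up to rounding.
g_recovered = g_scaled_fp16.astype(jnp.float32) / loss_scale

# bf16 shares fp32's exponent range, so the unscaled value survives.
g_bf16 = jnp.asarray(grad_value, dtype=jnp.bfloat16)

print(float(g_fp16))      # 0.0
print(float(g_bf16) > 0)  # True
```

The unscaling step deliberately happens after the cast to fp32; dividing inside fp16 would push the value straight back below the subnormal range.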

Why the cast inside the loss_fn, not outside?

Two reasons:

  1. The fp32 master copy of params lives outside loss_fn; only inside do we want bf16. After jax.grad finishes, we’re back to fp32 land.
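  2. jax.grad returns grads with the same dtype as the arguments it differentiates with respect to. Differentiating with respect to the fp32 params yields fp32 grads, which match the fp32 params that apply_gradients updates. If you cast before calling jax.grad, the grads come back as bf16, and mixing them with fp32 params gives dtype mismatches or silent precision loss.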
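A minimal sketch of reason 2, assuming a toy weighted-sum loss (the names `loss_cast_inside` and `loss_on_bf16` are hypothetical): the only difference between the two calls is whether the bf16 cast happens inside or outside the function passed to jax.grad.

```python
import jax
import jax.numpy as jnp

w_fp32 = jnp.ones((3,), dtype=jnp.float32)  # fp32 master weights
x = jnp.arange(3.0, dtype=jnp.float32)


def loss_cast_inside(w):
    # Cast happens inside the differentiated function.
    w16 = w.astype(jnp.bfloat16)
    return jnp.sum(w16 * x.astype(jnp.bfloat16)).astype(jnp.float32)


def loss_on_bf16(w16):
    # Expects params that were already cast before jax.grad was called.
    return jnp.sum(w16 * x.astype(jnp.bfloat16)).astype(jnp.float32)


g_inside = jax.grad(loss_cast_inside)(w_fp32)
g_outside = jax.grad(loss_on_bf16)(w_fp32.astype(jnp.bfloat16))

print(g_inside.dtype)   # float32: matches the fp32 params
print(g_outside.dtype)  # bfloat16: needs a cast before an fp32 update
```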

Common pitfalls

  • Casting params to bf16 BEFORE loss_fn: then the master copy is bf16 too — defeats the entire point.
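  • Forgetting to unscale the grads: the optimizer sees grads loss_scale times too large. With plain SGD this is the same as multiplying the learning rate by loss_scale, which usually makes training diverge.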
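  • Not casting x and y to bf16: JAX’s type promotion turns a bf16 × fp32 matmul into an fp32 matmul, so the forward pass still runs in fp32 and you get no speed-up (an fp32 y likewise promotes the loss computation to fp32).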
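  • Casting the loss back to fp32 BEFORE multiplying by loss_scale: this is not a mistake, it is what the recipe does, and it also keeps a large loss × loss_scale product from overflowing fp16 (max ≈ 65504). Order matters more for the grads: divide by loss_scale in fp32, since unscaling fp16 grads in fp16 can push small values back into underflow.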
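The type-promotion pitfall can be checked directly with jax.numpy. A minimal sketch, assuming toy shapes that stand in for the inputs and the Dense kernel:

```python
import jax.numpy as jnp

x_fp32 = jnp.ones((4, 8), dtype=jnp.float32)   # inputs left in fp32
w_bf16 = jnp.ones((8, 1), dtype=jnp.bfloat16)  # params cast to bf16

# bf16 combined with fp32 promotes to fp32, so this matmul runs in fp32.
print((x_fp32 @ w_bf16).dtype)                       # float32
# Casting the inputs as well keeps the matmul in bf16.
print((x_fp32.astype(jnp.bfloat16) @ w_bf16).dtype)  # bfloat16
print(jnp.promote_types(jnp.bfloat16, jnp.float32))  # float32
```

The same promotion typically applies when a model layer that infers its computation dtype combines fp32 inputs with bf16 params.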

Problem

Implement a single mixed-precision training step:

  1. Build a TrainState with a tiny nn.Dense(1) and optax.sgd(lr).
  2. Inside loss_fn: cast params, x, y to bf16; do MSE forward in bf16; scale loss by loss_scale (cast to fp32 first).
  3. Compute grads with jax.grad.
  4. Cast grads to fp32 and divide by loss_scale.
  5. apply_gradients.
  6. At the updated params, recompute the unscaled MSE in bf16 and cast it back to fp32. Return it as a 1-D array of shape (1,).

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D).
  • y: 1-D (N,).
  • lr: float.
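  • loss_scale: float, typically a power of 2 (e.g., 64, 128, 256), so that scaling and unscaling only shift the exponent and add no rounding error (barring overflow or underflow).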

Output: 1-D array of shape (1,): [final_loss_after_step_in_fp32].

Hints

flax mixed-precision bfloat16 training
