medium
primitives
bfloat16 Mixed Precision
Why this matters
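bfloat16 is the standard mixed-precision format on TPUs and modern GPUs.
It keeps the same 8-bit exponent as float32, so its dynamic range is
essentially the same, but it stores only 7 mantissa bits (8 bits of precision
counting the implicit leading bit) instead of float32's 23. The result:
each value takes 2 bytes instead of 4, halving memory use, and matmuls run
roughly 2× faster on hardware with native bfloat16 support, at the cost of
much lower precision (about 2 to 3 significant decimal digits instead of about 7).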
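```python
import jax.numpy as jnp
# x, y: float32 arrays; cast both operands to bfloat16
x_bf = x.astype(jnp.bfloat16)
y_bf = y.astype(jnp.bfloat16)
# dot product in bfloat16, then cast the result back to float32
result = jnp.dot(x_bf, y_bf).astype(jnp.float32)
```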
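A quick way to check the format claims above is to query the dtype metadata. The sketch below uses jnp.finfo to compare exponent and mantissa widths, checks the per-element storage size, and contrasts range with float16 (not discussed above; used here only as a contrast, since float16 has just a 5-bit exponent). Note that finfo reports 7 for the bfloat16 mantissa because it counts only the stored fraction bits; the leading 1 is implicit.

```python
import jax.numpy as jnp

# Exponent and stored mantissa widths reported by the dtype metadata.
bf16_info = jnp.finfo(jnp.bfloat16)
f32_info = jnp.finfo(jnp.float32)
print("bfloat16: exponent bits", bf16_info.nexp, "mantissa bits", bf16_info.nmant)
print("float32:  exponent bits", f32_info.nexp, "mantissa bits", f32_info.nmant)

# Storage cost per element: 2 bytes for bfloat16, 4 bytes for float32.
bf16_bytes = jnp.dtype(jnp.bfloat16).itemsize
f32_bytes = jnp.dtype(jnp.float32).itemsize

# bfloat16 shares float32's exponent width, so a large value stays finite,
# whereas float16 (5 exponent bits, max 65504) overflows to inf.
big = jnp.float32(1e30)
big_bf16 = big.astype(jnp.bfloat16)
big_f16 = big.astype(jnp.float16)
print(big_bf16, big_f16)
```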
Worked mini-example
```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
dot_bf16 = jnp.dot(x.astype(jnp.bfloat16),
                   y.astype(jnp.bfloat16)).astype(jnp.float32)
print(dot_bf16)  # 11.0 (1*3 + 2*4)
```
Common pitfalls
- Precision loss: bfloat16 has only 8 bits of precision (7 stored mantissa bits plus the implicit leading bit, about 2.4 decimal digits). Values that differ only in the third significant digit or later may round to the same bfloat16 value, so the result is less precise than the float32 equivalent.
- Keeping the result in bfloat16: always cast back to float32 before returning. JSON serialisation, like many downstream ops, expects float32.
- Float64 inputs: unless jax_enable_x64 is turned on (it is off by default), JAX silently converts float64 inputs to float32. Cast explicitly to bfloat16 regardless of the input dtype.
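The pitfalls above are easy to reproduce. In this minimal sketch (the input values are illustrative, not from the test set), 1.001 lies well within half a bfloat16 step of 1.0 (the spacing just above 1.0 is 2^-7, about 0.0078), so it rounds to exactly 1.0 and the third decimal digit of the dot product disappears. The sketch also shows that an uncast bfloat16 dot returns a bfloat16 result, and that an explicit cast yields bfloat16 operands whatever JAX does with float64 input under the current x64 setting.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([1.001, 1.0], dtype=jnp.float32)
y = jnp.array([1.0, 1.0], dtype=jnp.float32)

# Reference result in float32: about 2.001.
dot_f32 = jnp.dot(x, y)

# In bfloat16, 1.001 rounds to 1.0, so the small difference is lost.
x_bf = x.astype(jnp.bfloat16)
y_bf = y.astype(jnp.bfloat16)
dot_bf16_raw = jnp.dot(x_bf, y_bf)           # result dtype is still bfloat16
dot_bf16 = dot_bf16_raw.astype(jnp.float32)  # cast back before returning
print(dot_f32, dot_bf16, dot_bf16_raw.dtype)

# float64 NumPy input: JAX keeps or downcasts it depending on jax_enable_x64
# (float32 by default), but the explicit cast always gives bfloat16.
x64_input = np.array([1.0, 2.0], dtype=np.float64)
x_explicit = jnp.asarray(x64_input).astype(jnp.bfloat16)
```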
Problem
Implement bfloat16_dot(x, y) that:
- Casts both x and y to jnp.bfloat16.
- Computes jnp.dot in bfloat16.
- Casts the result back to jnp.float32 and returns it.
- x, y: 1-D JAX arrays of the same length.
Returns: scalar (float32).
Examples (not from the test set):
- bfloat16_dot(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])) → 11.0
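Following the three steps in the specification, one possible sketch of bfloat16_dot is shown below. It assumes 1-D inputs of equal length, as stated, and performs no further validation; the jnp.asarray call is an illustrative addition so that plain Python lists are also accepted.

```python
import jax.numpy as jnp

def bfloat16_dot(x, y):
    """Dot product with bfloat16 operands, returned as a float32 scalar."""
    # Step 1: cast both operands to bfloat16.
    x_bf = jnp.asarray(x).astype(jnp.bfloat16)
    y_bf = jnp.asarray(y).astype(jnp.bfloat16)
    # Step 2: dot product with bfloat16 inputs (output dtype is bfloat16).
    out_bf = jnp.dot(x_bf, y_bf)
    # Step 3: cast the result back to float32.
    return out_bf.astype(jnp.float32)

result = bfloat16_dot(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
print(result)
```

For 1-D inputs, jnp.dot returns a 0-d array, so the result is a float32 scalar as required.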
Hints
jax
bfloat16
mixed-precision