medium primitives

float8 Round-Trip

Why this matters

Float8 (fp8) is a low-precision format with native tensor-core support on NVIDIA H100 (Hopper) and Blackwell GPUs. It gives faster matmuls and uses half the memory of bfloat16, at the cost of significant precision loss. There are two standard variants:

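  • e4m3fn: 4-bit exponent, 3-bit mantissa, finite-only (the "fn" suffix: no infinities, NaN is the only special value). More precision bits but a narrower dynamic range (largest finite value 448.0). Typically used for weights and activations in the forward pass.
  • e5m2: 5-bit exponent, 2-bit mantissa, IEEE-style (keeps infinities and NaNs). Wider dynamic range (largest finite value 57344.0) but less precision. Typically used for gradients in the backward pass.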
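To see the range/precision trade-off between the two variants concretely, the sketch below reads each dtype's limits from `jnp.finfo` and compares them with bfloat16. The helper name `describe` is hypothetical, introduced only for this illustration. For e4m3fn it should report max=448.0 and eps=0.125; for e5m2, max=57344.0 and eps=0.25; and both fp8 types take 1 byte per element versus 2 for bfloat16.

```python
import jax.numpy as jnp


def describe(scalar_type):
    """Hypothetical helper: storage size and range/precision limits of a float dtype."""
    dt = jnp.dtype(scalar_type)
    info = jnp.finfo(dt)
    return {
        "name": dt.name,
        "bytes": dt.itemsize,    # storage per element
        "max": float(info.max),  # largest finite value (dynamic range)
        "eps": float(info.eps),  # gap between 1.0 and the next value up (precision)
    }


for t in (jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.bfloat16):
    print(describe(t))
```

The extra exponent bit in e5m2 buys more than a hundredfold more range, and it costs one mantissa bit, doubling eps.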

The typical pattern is: cast to fp8 for storage / matmul, cast back to float32 for accumulation and further ops.

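import jax.numpy as jnp

x = jnp.array([0.5, 1.5, 300.0], dtype=jnp.float32)  # any float32 array
x_fp8 = x.astype(jnp.float8_e4m3fn)   # lossy cast down to 1 byte per element
result = x_fp8.astype(jnp.float32)    # round-trip back to float32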
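The storage-in-fp8, compute-in-float32 pattern can be emulated for a matmul as a minimal sketch: both operands are cast to e4m3fn and then upcast, so the products are summed in float32. The function name `fp8_storage_matmul` is hypothetical. This only models the rounding of fp8 storage and does not exercise fp8 tensor cores. It also skips the per-tensor scaling factors that production fp8 recipes usually apply before casting.

```python
import jax
import jax.numpy as jnp


def fp8_storage_matmul(a, b):
    """Hypothetical helper: matmul of float32 inputs whose operands are stored as e4m3fn.

    Only the storage is fp8. Both operands are upcast to float32 before the
    matmul, so products are accumulated in float32.
    """
    a_fp8 = a.astype(jnp.float8_e4m3fn)  # 1 byte per element
    b_fp8 = b.astype(jnp.float8_e4m3fn)
    return jnp.matmul(a_fp8.astype(jnp.float32), b_fp8.astype(jnp.float32))


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (64, 64), dtype=jnp.float32)
b = jax.random.normal(key_b, (64, 64), dtype=jnp.float32)

exact = a @ b
approx = fp8_storage_matmul(a, b)
rel_err = jnp.linalg.norm(approx - exact) / jnp.linalg.norm(exact)
print(f"relative error from fp8 storage: {float(rel_err):.4f}")
```

Standard-normal inputs stay far below 448.0, so every error here comes from mantissa rounding, not overflow.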

Worked mini-example

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
x_fp8 = x.astype(jnp.float8_e4m3fn)
x_back = x_fp8.astype(jnp.float32)
print(x_back)  # [1. 2. 3.]  (all three are exactly representable in e4m3fn)

Common pitfalls

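  • Overflow gives NaN, not a saturated value: the largest finite e4m3fn value is 448.0, and because the finite-only variant has no infinity, a plain cast maps inputs that round past 448.0 to NaN. If you want saturation, clip to ±448.0 before casting.
  • Precision loss for irregular values: the 3-bit mantissa gives 4 significant bits (about 1 decimal digit). Within each power-of-two interval, neighboring representable values are 12.5% of the interval's lower bound apart, and every value rounds to the nearest one (relative error up to 6.25% for normal values).
  • CPU emulation: CPUs have no native fp8 support, so on CPU JAX emulates the format in software. The rounding still follows the e4m3fn format definition, so the round-trip precision loss you see is real and matches H100 behavior for in-range values.
  • Don’t accumulate in fp8: keep sums and further ops in float32, either by upcasting before the matmul or by requesting a float32 result (e.g. preferred_element_type=jnp.float32 in jnp.dot).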
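The first two pitfalls can be checked with a small sketch. The names `E4M3_MAX` and `to_e4m3fn_saturating` are hypothetical. The helper clamps inputs to ±448.0 with `jnp.clip` before casting, which gives saturating behavior instead of relying on what a plain cast does with out-of-range values. For the inputs below, the round-tripped values should be 0.3125, 1.125, 96.0, 288.0, 448.0 and -448.0.

```python
import jax.numpy as jnp

# Largest finite e4m3fn value (448.0), read from the dtype instead of hard-coded.
E4M3_MAX = float(jnp.finfo(jnp.dtype(jnp.float8_e4m3fn)).max)


def to_e4m3fn_saturating(x):
    """Hypothetical helper: cast to e4m3fn, clamping to +/-448 first.

    e4m3fn has no infinity, so without the clip, inputs that round past
    448.0 are not clamped by the cast itself.
    """
    return jnp.clip(x, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)


x = jnp.array([0.3, 1.1, 100.0, 300.0, 1000.0, -1000.0], dtype=jnp.float32)
saturated = to_e4m3fn_saturating(x).astype(jnp.float32)
print(saturated)

# Relative rounding error of the in-range entries: at most 2**-4 for normal values.
rel_err = jnp.abs(saturated[:4] - x[:4]) / jnp.abs(x[:4])
print(rel_err)
```

Clipping trades a NaN for a bounded but possibly large error, which is usually the preferred failure mode for stored weights.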

Problem

Implement fp8_roundtrip(x):

  1. Cast x to jnp.float8_e4m3fn.
  2. Cast the result back to jnp.float32.
  3. Return the result.

Args: x, a 1-D JAX float32 array.

Returns: 1-D float32 array of the same shape as x, round-tripped through fp8 (lossy for any value not exactly representable in e4m3fn).

Examples (not from the test set):

  • fp8_roundtrip(jnp.array([1.0])) → [1.0] (exact in e4m3fn)
  • fp8_roundtrip(jnp.array([100.0])) → [96.0] (100.0 lies exactly halfway between its e4m3fn neighbors 96.0 and 104.0, and the tie rounds to the even mantissa, 96.0)
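One straightforward way to satisfy the spec is a sketch that follows the three steps literally. It is not necessarily the site's reference solution. The downcast is where all the rounding happens. Widening back to float32 is exact, because every e4m3fn value is representable in float32.

```python
import jax.numpy as jnp


def fp8_roundtrip(x):
    """Round-trip a 1-D float32 array through float8_e4m3fn and back."""
    x_fp8 = x.astype(jnp.float8_e4m3fn)  # round to the nearest e4m3fn value
    return x_fp8.astype(jnp.float32)     # widen back; exact, no further loss


print(fp8_roundtrip(jnp.array([1.0])))    # [1.]
print(fp8_roundtrip(jnp.array([100.0])))  # [96.]
```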

Hints

jax fp8 mixed-precision
