int8 Affine Quantization
Why this matters
Affine (asymmetric) quantization maps a floating-point tensor to 8-bit
integers using a scale and a zero_point:
q = round(x / scale) + zero_point, clipped to [-128, 127]
This is the core op in post-training quantization (PTQ), quantization-aware
training (QAT), and int8 deployment to inference engines (TensorRT,
TFLite, ONNX Runtime). The scale sets the step size; zero_point shifts
the integer grid so that real zero is exactly representable, which matters
for sparse activations such as ReLU output.
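To make the "exactly representable zero" point concrete, here is a minimal sketch. The parameter values are hypothetical, and the dequantization step `(q - zero_point) * scale` is an assumed inverse that the text above does not define; it is included only to show that an input of 0.0 survives the round trip unchanged.

```python
import jax.numpy as jnp

# Hypothetical parameters, chosen for illustration only.
scale, zero_point = 0.1, -128.0

x = jnp.array([0.0, 0.04, 1.0])

# Quantize with the formula from the text: round(x / scale) + zero_point, then clip.
q = jnp.clip(jnp.round(x / scale) + zero_point, -128.0, 127.0)

# Assumed inverse mapping (not part of the original text).
x_hat = (q - zero_point) * scale

# Real zero maps to zero_point and dequantizes back to exactly 0.0.
zero_is_exact = bool(x_hat[0] == 0.0)
print(q[0], x_hat[0], zero_is_exact)
```

Any other input may pick up a rounding error of up to half a step, but zero never does, as long as zero_point itself is an integer.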
Worked mini-example
import jax.numpy as jnp
x = jnp.array([0.0, 0.5, 1.0])
scale, zero_point = 0.5, 0.0
quantized = jnp.round(x / scale) + zero_point
# round([0, 1, 2]) + 0 = [0, 1, 2]
clipped = jnp.clip(quantized, -128.0, 127.0)
# → [0.0, 1.0, 2.0]
Common pitfalls
- Forgetting to clip: out-of-range values silently wrap or saturate depending on downstream ops, so always clip to [-128, 127].
- Rounding ties go to even (banker’s rounding): jnp.round(0.5) → 0.0, not 1.0. This is the IEEE 754 default and matches most hardware.
- scale must be positive: a negative or zero scale produces garbage quants.
- Clip range is the full int8 range: the signed 8-bit range is [-128, 127], not the symmetric [-127, 127].
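The rounding and clipping pitfalls can be checked directly. This is a small sketch using only `jnp.round` and `jnp.clip`; the input values are arbitrary examples.

```python
import jax.numpy as jnp

# Ties go to the nearest even integer, not away from zero.
ties = jnp.round(jnp.array([0.5, 1.5, 2.5]))
print(ties)  # [0. 2. 2.]

# Clipping saturates at the int8 bounds, including -128.
clipped = jnp.clip(jnp.array([-300.0, -128.0, 127.0, 300.0]), -128.0, 127.0)
print(clipped)  # [-128. -128.  127.  127.]
```

Note that 0.5 and 2.5 both round down while 1.5 rounds up, so code that assumes "round half up" is off by one on half of the ties.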
Problem
Implement affine_quantize(x, scale, zero_point):
- Divide x by scale.
- Round to the nearest integer (banker’s rounding).
- Add zero_point.
- Clip to [-128.0, 127.0].
- Return the result as a float array (same shape as x).

- x: 1-D JAX float array.
- scale, zero_point: Python floats (scalars).
Returns: 1-D float array — quantized integer values stored as floats.
Examples (not from the test set):
- affine_quantize(jnp.array([0.0, 0.5, 1.0]), 0.5, 0.0) → [0.0, 1.0, 2.0]
- affine_quantize(jnp.array([300.0]), 1.0, 0.0) → [127.0] (clipped)
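One possible reading of the steps above, as a minimal sketch rather than the reference solution. It assumes scale is positive and does not validate its inputs; the intermediate variable names are illustrative.

```python
import jax.numpy as jnp


def affine_quantize(x, scale, zero_point):
    """Quantize x onto the int8 grid, returning the values as floats."""
    scaled = x / scale                        # step 1: divide by scale
    rounded = jnp.round(scaled)               # step 2: round, ties to even
    shifted = rounded + zero_point            # step 3: add zero_point
    return jnp.clip(shifted, -128.0, 127.0)   # step 4: clip to the int8 range


in_range = affine_quantize(jnp.array([0.0, 0.5, 1.0]), 0.5, 0.0)
saturated = affine_quantize(jnp.array([300.0]), 1.0, 0.0)
print(in_range)   # [0. 1. 2.]
print(saturated)  # [127.]
```

The result stays a float array because no integer cast is applied. Clipping happens after zero_point is added, so the final values are always within [-128, 127].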