int8 Affine Quantization

Why this matters

Affine (asymmetric) quantization maps a floating-point tensor to 8-bit integers using a scale and a zero_point:

q = round(x / scale) + zero_point,  clipped to [-128, 127]

This is the core operation behind post-training quantization (PTQ), quantization-aware training (QAT), and deployment to int8 inference engines such as TensorRT, TFLite, and ONNX Runtime. scale is the real-valued width of one integer step. zero_point is the integer that the real value 0.0 maps to, so 0.0 is represented exactly with no rounding error. That matters for tensors where exact zeros are common, such as ReLU outputs.
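To see why zero_point makes 0.0 exact, the sketch below picks scale and zero_point with min/max calibration, then maps 0.0 through quantize and dequantize. Min/max calibration is an assumption here, since the problem supplies both parameters as inputs. minmax_qparams and dequantize are hypothetical helpers written only for this illustration.

```python
import jax.numpy as jnp

def minmax_qparams(x_min, x_max, qmin=-128, qmax=127):
    # Hypothetical helper: min/max calibration, one common way to pick
    # scale and zero_point (an assumption, not given in the problem).
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # range must contain 0.0
    if x_max == x_min:
        raise ValueError("empty range; scale would be zero")
    scale = (x_max - x_min) / (qmax - qmin)           # real width of one int8 step
    zero_point = float(round(qmin - x_min / scale))   # integer that 0.0 maps to
    return scale, zero_point

def dequantize(q, scale, zero_point):
    # Inverse affine map: x_hat = (q - zero_point) * scale
    return (q - zero_point) * scale

scale, zp = minmax_qparams(-1.0, 3.0)  # scale = 4/255, zp = -64.0
q0 = jnp.clip(jnp.round(jnp.array([0.0]) / scale) + zp, -128.0, 127.0)
x0 = dequantize(q0, scale, zp)         # → [0.0], recovered exactly
```

Because zero_point is snapped to an integer, 0.0 quantizes to exactly zp and dequantizes back to exactly 0.0. Other real values such as -1.0 only come back approximately.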

Worked mini-example

import jax.numpy as jnp

x = jnp.array([0.0, 0.5, 1.0])
scale, zero_point = 0.5, 0.0

quantized = jnp.round(x / scale) + zero_point
# round([0, 1, 2]) + 0 = [0, 1, 2]
clipped = jnp.clip(quantized, -128.0, 127.0)
# → [0.0, 1.0, 2.0]

Common pitfalls

  • Forgetting to clip: the formula itself does not bound q, and out-of-range values that are later cast to int8 may wrap around or saturate depending on the runtime. Always clip to [-128, 127] first.
  • Ties round to even (banker’s rounding): jnp.round(0.5) → 0.0, not 1.0, and jnp.round(2.5) → 2.0. Round-half-to-even is the IEEE 754 default rounding mode.
  • scale must be positive: a zero scale divides by zero and produces inf or nan, and a negative scale flips the sign of every value.
  • Use the full int8 range: signed 8-bit integers span [-128, 127]. Do not clip to the symmetric range [-127, 127].
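The pitfalls above can each be checked in a few lines. This is a small sketch, and check_scale is a hypothetical helper, not part of the problem’s required interface.

```python
import jax.numpy as jnp

# Pitfall: ties round to the nearest even integer, not away from zero.
ties = jnp.round(jnp.array([0.5, 1.5, 2.5, -1.5]))
# → [0.0, 2.0, 2.0, -2.0]

# Pitfall: clip before casting so every value fits int8 and the cast is exact.
q = jnp.clip(jnp.round(jnp.array([300.0, -300.0, 3.0])), -128.0, 127.0)
q_int8 = q.astype(jnp.int8)
# → [127, -128, 3]

def check_scale(scale):
    # Hypothetical helper: reject zero and negative scales up front.
    # `not scale > 0` also rejects NaN, which `scale <= 0` would let through.
    if not scale > 0:
        raise ValueError(f"scale must be positive, got {scale}")
    return scale
```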

Problem

Implement affine_quantize(x, scale, zero_point):

  1. Divide x by scale.
  2. Round to the nearest integer, with ties to even (banker’s rounding).
  3. Add zero_point.
  4. Clip to [-128.0, 127.0].
  5. Return the result as a float array (same shape as x).
Arguments:
  • x: 1-D JAX float array.
  • scale, zero_point: Python floats (scalars).

Returns: 1-D float array — quantized integer values stored as floats.

Examples (not from the test set):

  • affine_quantize(jnp.array([0.0, 0.5, 1.0]), 0.5, 0.0) → [0.0, 1.0, 2.0]
  • affine_quantize(jnp.array([300.0]), 1.0, 0.0) → [127.0] (clipped)
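The five steps can be sketched directly in JAX and checked against the examples. This is one straightforward reading of the spec, not the site’s reference solution. The jitted wrapper quantize_jit is an illustrative addition that the problem does not require.

```python
import jax
import jax.numpy as jnp

def affine_quantize(x, scale, zero_point):
    """Affine int8 quantization; returns int8 values stored as floats."""
    q = jnp.round(x / scale)           # steps 1-2: divide, round half to even
    q = q + zero_point                 # step 3: shift by the zero point
    return jnp.clip(q, -128.0, 127.0)  # step 4: saturate to the int8 range
    # step 5: every op is elementwise on floats, so shape and float dtype are kept

quantize_jit = jax.jit(affine_quantize)  # optional: same result, compiled

a = affine_quantize(jnp.array([0.0, 0.5, 1.0]), 0.5, 0.0)  # → [0.0, 1.0, 2.0]
b = quantize_jit(jnp.array([300.0]), 1.0, 0.0)             # → [127.0]
```

Keeping the result as floats matches step 5. Cast with .astype(jnp.int8) only when int8 storage is actually needed, after the clip.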

