
int8 Dequantization

Why this matters

Dequantization is the inverse of affine quantization. Given quantized integer values q, a scale, and a zero_point, it recovers an approximation of the original floats:

x_hat = (q - zero_point) * scale

This is used in every quantized inference pipeline: weights are stored as int8 and dequantized before (or during) the matmul. It also appears in quantization-aware training (QAT) to simulate int8 precision during the forward pass while keeping float32 for gradients.
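As a rough illustration of the inference case, the sketch below dequantizes a small int8 weight matrix and then runs an ordinary float matmul. The weight values, the input x, and the per-tensor scale and zero_point are made up for the example; real pipelines may fuse these steps or use per-channel parameters.

```python
import jax.numpy as jnp

# Hypothetical int8 weights with per-tensor quantization parameters.
w_q = jnp.array([[10, -20], [30, 40]], dtype=jnp.int8)
scale, zero_point = 0.05, 0

# Cast to float32 before subtracting so the shift cannot overflow int8.
w_hat = (w_q.astype(jnp.float32) - zero_point) * scale

# Float matmul on the dequantized weights.
x = jnp.array([[1.0, 2.0]], dtype=jnp.float32)
y = x @ w_hat
```

Casting before the subtraction is a deliberate choice here: with a nonzero zero_point, q - zero_point can fall outside the int8 range.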

Note: dequantization is lossy. Rounding during quantization destroys information that cannot be recovered, so x_hat is only an approximation of the original x.
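To make the information loss concrete, here is a minimal round-trip sketch. The helper affine_quantize is hypothetical and assumes one common convention (round to nearest, add zero_point, clip to the int8 range); the source does not specify the quantizer, so treat it as an illustration only.

```python
import jax.numpy as jnp

def affine_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Assumed convention: round to nearest, shift by zero_point, clip to int8 range.
    q = jnp.round(x / scale) + zero_point
    return jnp.clip(q, qmin, qmax)

x = jnp.array([0.12, 0.49, 0.77])
scale, zero_point = 0.5, 0.0

q = affine_quantize(x, scale, zero_point)
x_hat = (q - zero_point) * scale

# For inputs inside the clipping range, the error is at most scale / 2.
err = jnp.abs(x - x_hat)
```

Under this convention the three inputs collapse onto the grid points 0.0, 0.5, and 1.0, and nothing in q lets dequantization tell 0.49 apart from 0.5.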

Worked mini-example

import jax.numpy as jnp

# Quantized integer values, stored as floats
q = jnp.array([0.0, 1.0, 2.0])
scale, zero_point = 0.5, 0.0

# Subtract the zero point first, then scale
x_hat = (q - zero_point) * scale
# x_hat = [0.0, 0.5, 1.0]

Common pitfalls

  • Order matters: subtract zero_point first, then multiply by scale. The distributed form (q * scale) - (zero_point * scale) is algebraically equivalent, but floating-point rounding of the intermediate products can make its result differ slightly.
  • Lossy: the result is approximate because quantization rounds.
  • zero_point can be negative: e.g., zero_point = -5 shifts the quantized range up by 5 before scaling.
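The first and third pitfalls can be checked with a short sketch. The values of q, scale, and zero_point are arbitrary choices for illustration; the point is only that the two orderings are computed through different intermediate roundings, and that a negative zero_point adds to q before scaling.

```python
import jax.numpy as jnp

q = jnp.array([-128.0, 0.0, 5.0, 127.0])
scale, zero_point = 0.1, -5.0

# Recommended order: shift first, then scale.
a = (q - zero_point) * scale

# Distributed form: same algebra, different intermediate roundings.
b = q * scale - zero_point * scale

# Any gap is tiny, but it is not guaranteed to be exactly zero.
max_gap = jnp.max(jnp.abs(a - b))

# With zero_point = -5, q = 5 is shifted to 10 and then scaled to about 1.0.
shifted_value = a[2]
```

Comparing a and b with a tolerance (for example jnp.allclose) rather than exact equality is the safer habit when both forms appear in a codebase.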

Problem

Implement affine_dequantize(q, scale, zero_point):

  1. Subtract zero_point from q.
  2. Multiply by scale.
  3. Return the result.

Arguments:

  • q: 1-D JAX float array (integer values stored as floats).
  • scale, zero_point: Python floats (scalars).

Returns: 1-D float array, the recovered float approximation.

Examples (not from the test set):

  • affine_dequantize(jnp.array([0.0, 1.0, 2.0]), 0.5, 0.0) → [0.0, 0.5, 1.0]
  • affine_dequantize(jnp.array([5.0]), 0.1, -5.0) → [1.0]
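One way to check the two examples above is a direct transcription of the three listed steps. This is a minimal sketch under the stated signature, not necessarily the site's reference solution; the names out1 and out2 are only for illustration.

```python
import jax.numpy as jnp

def affine_dequantize(q, scale, zero_point):
    # Shift by the zero point first, then rescale to float units.
    return (q - zero_point) * scale

out1 = affine_dequantize(jnp.array([0.0, 1.0, 2.0]), 0.5, 0.0)
out2 = affine_dequantize(jnp.array([5.0]), 0.1, -5.0)
```

Because scale and zero_point are Python scalars, they broadcast across q and the output keeps the shape and float dtype of the input array.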

