int8 Dequantization
Why this matters
Dequantization is the inverse of affine quantization. Given quantized
integer values q, scale, and zero_point, it recovers an approximation
of the original floats:
x_hat = (q - zero_point) * scale
This is used in most quantized inference pipelines: weights are stored as int8 and dequantized before (or during) the matmul. It also appears in quantization-aware training (QAT) to simulate int8 precision during the forward pass while keeping float32 for gradients.
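Both uses can be sketched in a few lines of JAX. The sketch makes several assumptions:

- A single per-tensor scale and zero_point, with a signed int8 range of [-128, 127].
- int8_linear and fake_quant are hypothetical names.
- The straight-through estimator (STE) in fake_quant is one common way to keep float32 gradients flowing past the rounding step. This page does not prescribe it.

```python
import jax
import jax.numpy as jnp

def dequantize(q, scale, zero_point):
    # Cast int8 codes to float32 before subtracting so no 8-bit arithmetic is involved
    return (q.astype(jnp.float32) - zero_point) * scale

def int8_linear(x, w_q, w_scale, w_zero_point):
    # Hypothetical weight-only int8 layer: dequantize the weights, then a float matmul
    w = dequantize(w_q, w_scale, w_zero_point)
    return x @ w

def fake_quant(x, scale, zero_point, qmin=-128, qmax=127):
    # QAT-style simulation (hypothetical helper): quantize, then immediately dequantize
    q = jnp.clip(jnp.round(x / scale) + zero_point, qmin, qmax)
    x_hat = dequantize(q, scale, zero_point)
    # Straight-through estimator: the forward value is x_hat, but the gradient
    # with respect to x is passed through as if this op were the identity
    return x + jax.lax.stop_gradient(x_hat - x)

# Inference: int8 weights stored compactly, dequantized just before the matmul
x = jnp.ones((2, 3), dtype=jnp.float32)
w_q = jnp.array([[1, -2], [3, 4], [-5, 6]], dtype=jnp.int8)
y = int8_linear(x, w_q, 0.5, 0.0)  # [[-0.5, 4.0], [-0.5, 4.0]]

# Training: forward pass sees int8-rounded values, gradients stay float32
x_train = jnp.array([0.12, -0.37])
grads = jax.grad(lambda v: jnp.sum(fake_quant(v, 0.1, 0.0)))(x_train)  # [1.0, 1.0]
```

Real kernels often fuse the dequantize into the matmul instead of materializing a float weight matrix. This version materializes it for clarity. This simplest STE variant also passes gradients through values that were clipped, and some QAT setups zero those instead.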
Note: dequantization is lossy. Rounding during quantization destroys
information that cannot be recovered, so x_hat is only an approximation
of the original x.
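To see the loss concretely, here is a round trip through a hypothetical affine_quantize helper. Its formula is the common int8 convention and is an assumption here, since this page only defines the dequantize direction:

- divide by scale, round, add zero_point, clip to [-128, 127].

For inputs inside the representable range, the round-trip error is at most scale / 2. Clipped inputs can be off by more.

```python
import jax.numpy as jnp

def affine_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Hypothetical helper: common int8 convention (round, shift, clip);
    # codes are kept as float32, matching how q is stored on this page
    return jnp.clip(jnp.round(x / scale) + zero_point, qmin, qmax)

def affine_dequantize(q, scale, zero_point):
    # Inverse map: subtract zero_point, then multiply by scale
    return (q - zero_point) * scale

x = jnp.array([0.12, 0.14, -0.37, 1.01])
scale, zero_point = 0.1, 0.0

q = affine_quantize(x, scale, zero_point)        # codes [1, 1, -4, 10]
x_hat = affine_dequantize(q, scale, zero_point)  # roughly [0.1, 0.1, -0.4, 1.0]

# 0.12 and 0.14 collapse to the same code, so no dequantizer can tell them apart
print(x_hat[0] == x_hat[1])          # True
# For in-range inputs the round-trip error is bounded by scale / 2
print(jnp.max(jnp.abs(x - x_hat)))   # about 0.04, within scale / 2 = 0.05
```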
Worked mini-example
import jax.numpy as jnp
q = jnp.array([0.0, 1.0, 2.0])
scale, zero_point = 0.5, 0.0
x_hat = (q - zero_point) * scale
# = [0.0, 0.5, 1.0]
Common pitfalls
- Order matters: subtract zero_point FIRST, then multiply by scale. (q * scale) - zero_point * scale is algebraically equivalent, but in floating point it rounds each product and then the difference. Subtracting first keeps q - zero_point exact for integer-valued inputs, so only the final multiply rounds.
- Lossy: the result is approximate because quantization rounds.
- zero_point can be negative: e.g., zero_point = -5 shifts the quantized range up by 5 before scaling.
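A quick way to probe the order pitfall is sketched below. It runs over every int8 code stored as float32. The scale of 0.0123 and zero_point of 7 are arbitrary values chosen only for illustration.

The two forms agree to within a few float32 ulps. Whether any element actually differs depends on the values, so the sketch measures the gap rather than assuming one.

```python
import jax.numpy as jnp

# Every int8 code, stored as float32 (scale and zero_point are illustrative)
q = jnp.arange(-128, 128, dtype=jnp.float32)
scale, zero_point = 0.0123, 7.0

# Recommended: q - zero_point is exact for these integers, so only the multiply rounds
subtract_first = (q - zero_point) * scale
# Algebraically equal, but rounds each product and then the difference
expanded = q * scale - zero_point * scale

max_gap = jnp.max(jnp.abs(subtract_first - expanded))
print(max_gap)  # small, possibly 0.0; the exact value depends on float32 rounding
```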
Problem
Implement affine_dequantize(q, scale, zero_point):
- Subtract zero_point from q.
- Multiply by scale.
- Return the result.

Inputs:
- q: 1-D JAX float array (integer values stored as floats).
- scale, zero_point: Python floats (scalars).

Returns: 1-D float array, the recovered float approximation.
Examples (not from the test set):
- affine_dequantize(jnp.array([0.0, 1.0, 2.0]), 0.5, 0.0) → [0.0, 0.5, 1.0]
- affine_dequantize(jnp.array([5.0]), 0.1, -5.0) → [1.0]
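The steps in the problem statement translate almost directly into JAX. What follows is a sketch under the spec above, not an official solution.

The float32 cast goes beyond the spec, which promises float input. It is included as an assumption so that the same function also accepts raw int8 codes.

```python
import jax.numpy as jnp

def affine_dequantize(q, scale, zero_point):
    # Cast to float32 first so integer-typed q (e.g. int8 codes) never goes
    # through 8-bit integer subtraction, which can wrap around
    q = jnp.asarray(q).astype(jnp.float32)
    # Subtract zero_point first, then multiply by scale (scalars broadcast over q)
    return (q - zero_point) * scale

print(affine_dequantize(jnp.array([0.0, 1.0, 2.0]), 0.5, 0.0))  # values 0.0, 0.5, 1.0
print(affine_dequantize(jnp.array([5.0]), 0.1, -5.0))           # value 1.0
```

Because scale and zero_point are scalars, broadcasting handles the whole array in one expression. No explicit loop is needed.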