Straight-Through Estimator
Why this matters
Discrete operations (sign, argmax, hard thresholding, vector
quantization) have zero gradient almost everywhere. That makes them
useless for backprop: the optimization landscape is flat everywhere the
discrete boundary isn't crossed.
The Straight-Through Estimator (STE) fixes this by defining a surrogate gradient: forward pass uses the discrete op, backward pass uses the identity (gradient = 1). The discrete output reaches the loss, but the gradient signal still flows back as if the discrete op were just a pass-through.
STE powers:
- VQ-VAE (vector-quantization backward pass).
- BinaryConnect / BinaryNet (binary weight training).
- BinaryNet inference layers in edge models.
- Gumbel-softmax hard samples (forward hard, backward soft).
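As an illustration of the vector-quantization case, here is a minimal sketch of a nearest-codebook lookup with a straight-through backward pass. The function name vq_straight_through, the codebook values, and the inputs are all made up for this example; it is not the VQ-VAE reference implementation, only the same x + stop_gradient(g(x) - x) pattern applied to a codebook lookup.

```python
import jax
import jax.numpy as jnp


def vq_straight_through(z, codebook):
    """Snap each row of z to its nearest codebook row, with an STE backward pass.

    z:        (n, d) encoder outputs
    codebook: (k, d) code vectors (hypothetical name and layout)
    """
    # Squared Euclidean distance between every z row and every code: (n, k)
    dists = jnp.sum((z[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)
    idx = jnp.argmin(dists, axis=-1)  # discrete choice, no useful gradient
    z_q = codebook[idx]               # quantized vectors, shape (n, d)
    # Forward value is z_q; backward treats the whole step as identity in z.
    return z + jax.lax.stop_gradient(z_q - z)


codebook = jnp.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])
z = jnp.array([[0.9, 1.2], [-0.8, 1.7]])

z_q = vq_straight_through(z, codebook)
grad_z = jax.grad(lambda v: vq_straight_through(v, codebook).sum())(z)

print(z_q)     # rows snapped to codebook entries 1 and 2
print(grad_z)  # all ones: the gradient passes straight through to z
```

In this sketch the codebook receives no gradient through the quantization step, because it only appears inside stop_gradient. VQ-VAE trains the codebook with separate loss terms, which are omitted here.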
Worked mini-example
import jax
import jax.numpy as jnp

# Naïve approach: zero gradient
grad_sign = jax.grad(lambda x: jnp.sign(x).sum())
print(grad_sign(jnp.array([0.5])))  # [0.]  (useless)

# STE trick: x + stop_gradient(sign(x) - x)
def ste_quantize(x):
    return x + jax.lax.stop_gradient(jnp.sign(x) - x)

# Forward: x + (sign(x) - x) = sign(x)
print(ste_quantize(jnp.array([-1.5, 0.0, 0.5, 2.0])))
# -> [-1.  0.  1.  1.]

# Backward: stop_gradient zeros the inner derivative, so grad = 1
grad_ste = jax.grad(lambda x: ste_quantize(x).sum())
print(grad_ste(jnp.array([0.5])))  # [1.]  (useful)
Common pitfalls
- Writing jnp.sign(x) directly: the gradient is 0 almost everywhere; the surrogate trick is necessary for learning.
- The algebraic trick x + stop_gradient(g(x) - x):
  - Forward: x + (g(x) - x) = g(x), the discrete output.
  - Backward: stop_gradient zeros d/dx [g(x) - x], leaving d/dx [x] = 1, the identity pass-through.
- Don't use custom_vjp for STE: the algebraic trick is cleaner, shorter, and composable.
- jnp.sign(0) == 0: the forward value at zero is 0, not ±1. This is the standard convention for sign.
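The points above can be checked with a small generic helper. This is a sketch under assumptions: the name straight_through and the sample inputs are invented for illustration, and jnp.round stands in for an arbitrary discrete op g.

```python
import jax
import jax.numpy as jnp


def straight_through(g, x):
    """Forward: g(x). Backward: identity in x. (Hypothetical helper name.)"""
    return x + jax.lax.stop_gradient(g(x) - x)


x = jnp.array([-1.4, 0.0, 0.6, 2.3])

# Same trick, different discrete op: rounding instead of sign.
rounded = straight_through(jnp.round, x)
grad_round = jax.grad(lambda v: straight_through(jnp.round, v).sum())(x)

# The helper is ordinary traced arithmetic, so it composes with jit and vmap.
per_element_grad = jax.jit(
    jax.vmap(jax.grad(lambda v: straight_through(jnp.sign, v)))
)(x)

# The forward value at zero follows jnp.sign: 0, not +1 or -1.
sign_at_zero = straight_through(jnp.sign, jnp.array(0.0))

print(rounded)           # x rounded to the nearest integer
print(grad_round)        # all ones
print(per_element_grad)  # all ones
print(sign_at_zero)      # 0.0
```

Because the helper needs no custom differentiation rule, swapping in a different g only changes the forward value; the backward pass stays the identity.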
Problem
Implement ste_quantize(x) that returns jnp.sign(x) in the forward
pass and passes gradient through unchanged (identity) in the backward pass.
- x: 1-D JAX array.
Returns: 1-D array of the same shape, equal to sign(x). Gradient (under autodiff): identity (1 everywhere).
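One way to read "passes gradient through unchanged" is that whatever gradient arrives from the loss leaves the function untouched. The sketch below checks that reading with jax.vjp; ste_quantize repeats the definition from the worked mini-example, and the upstream values are arbitrary numbers chosen for illustration.

```python
import jax
import jax.numpy as jnp


def ste_quantize(x):
    # Same definition as in the worked mini-example
    return x + jax.lax.stop_gradient(jnp.sign(x) - x)


x = jnp.array([-1.5, 0.0, 0.5, 2.0])
# Made-up gradient arriving from the loss, one entry per output element
upstream = jnp.array([0.1, -2.0, 3.0, 0.5])

# jax.vjp returns the forward output and a function that pulls gradients back
y, vjp_fn = jax.vjp(ste_quantize, x)
(downstream,) = vjp_fn(upstream)

print(y)           # forward: sign(x)
print(downstream)  # backward: equal to upstream
```

A check like this covers both halves of the specification: the forward output matches jnp.sign(x), and the pulled-back gradient matches the upstream one.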