
Straight-Through Estimator

Why this matters

Discrete operations (sign, argmax, hard thresholding, vector quantization) have zero gradient almost everywhere. That makes them useless for backprop. Their output is piecewise constant in the input, so the optimization landscape is flat everywhere except at the discrete boundaries, where the gradient is undefined.
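A quick way to confirm this in JAX is to differentiate a few of these piecewise-constant ops directly. This sketch uses jnp.round as a stand-in for a simple quantizer and a hypothetical 0/1 step, written with jnp.where, as the hard threshold. Autodiff returns zeros for all of them.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-0.7, 0.2, 1.6])

# Rounding (a simple uniform quantizer) is piecewise constant
grad_round = jax.grad(lambda v: jnp.round(v).sum())(x)

# Hard 0/1 threshold: v only enters through the comparison, which has no derivative
grad_step = jax.grad(lambda v: jnp.where(v > 0, 1.0, 0.0).sum())(x)

# Sign: JAX defines its derivative as zero
grad_sign = jax.grad(lambda v: jnp.sign(v).sum())(x)

print(grad_round, grad_step, grad_sign)  # all zeros
```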

The Straight-Through Estimator (STE) fixes this by defining a surrogate gradient: forward pass uses the discrete op, backward pass uses the identity (gradient = 1). The discrete output reaches the loss, but the gradient signal still flows back as if the discrete op were just a pass-through.
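The same recipe works for any discrete op, not just sign. The minimal sketch below uses a hypothetical helper, straight_through(g). It wraps a discrete function g so that the forward pass returns g(x) while reverse-mode autodiff sees the identity. It relies on the x + stop_gradient(g(x) - x) trick that the worked example applies to sign.

```python
import jax
import jax.numpy as jnp

def straight_through(g):
    """Hypothetical helper: forward computes g(x), backward is the identity."""
    def wrapped(x):
        # The stop_gradient term contributes its value but no derivative,
        # so d(wrapped)/dx = d(x)/dx = 1.
        return x + jax.lax.stop_gradient(g(x) - x)
    return wrapped

ste_round = straight_through(jnp.round)
ste_step = straight_through(lambda v: jnp.where(v > 0, 1.0, 0.0))

x = jnp.array([-0.7, 0.2, 1.6])
y_round = ste_round(x)                                 # rounded values
g_round = jax.grad(lambda v: ste_round(v).sum())(x)    # ones
y_step = ste_step(x)                                   # 0/1 step values
g_step = jax.grad(lambda v: ste_step(v).sum())(x)      # ones
```

The forward value is computed as x + (g(x) - x), so in floating point it can differ from g(x) by a rounding error. Compare it with a tolerance rather than exact equality.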

STE powers:

  • VQ-VAE (vector-quantization backward pass).
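  • BinaryConnect / BinaryNet (training with binary weights; BinaryNet also binarizes activations).
  • Quantization-aware training of low-bit models for edge deployment (the STE is only needed during training; inference runs the discrete op directly).
  • Straight-through Gumbel-softmax (forward: hard one-hot sample; backward: gradient of the soft relaxation rather than the identity).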
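Two of the uses listed above can be sketched in a few lines under some assumptions. The function names, the (n, d) and (k, d) shapes, and the temperature tau are hypothetical. Neither function is a reference implementation of VQ-VAE or of any library's Gumbel-softmax. The VQ sketch snaps each encoder vector to its nearest codebook entry and copies the gradient straight back to the encoder. The Gumbel sketch returns a hard one-hot sample whose gradient is that of the soft relaxation.

```python
import jax
import jax.numpy as jnp

def vq_straight_through(z, codebook):
    """Hypothetical VQ step. z: (n, d) encoder outputs; codebook: (k, d) entries."""
    # Squared distance from every z_i to every codebook entry e_j: shape (n, k)
    d2 = jnp.sum((z[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)
    idx = jnp.argmin(d2, axis=-1)   # discrete choice, no useful gradient
    z_q = codebook[idx]             # (n, d) quantized vectors
    # Forward: z_q. Backward: the gradient w.r.t. the output is copied to z unchanged.
    return z + jax.lax.stop_gradient(z_q - z), idx

def gumbel_softmax_st(key, logits, tau=1.0):
    """Hypothetical straight-through Gumbel-softmax over the last axis."""
    gumbel = jax.random.gumbel(key, logits.shape)
    y_soft = jax.nn.softmax((logits + gumbel) / tau, axis=-1)
    y_hard = jax.nn.one_hot(jnp.argmax(y_soft, axis=-1), logits.shape[-1])
    # Forward: one-hot y_hard. Backward: gradient of the soft sample y_soft.
    return y_soft + jax.lax.stop_gradient(y_hard - y_soft)
```

In vq_straight_through, stop_gradient also blocks any gradient to the codebook along this path. In VQ-VAE, the codebook is trained by separate loss terms (a codebook loss and a commitment loss), not through the straight-through output.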

Worked mini-example

import jax
import jax.numpy as jnp

# Naïve approach: zero gradient
grad_sign = jax.grad(lambda x: jnp.sign(x).sum())
print(grad_sign(jnp.array([0.5])))   # [0.]  ← useless

# STE trick: x + stop_gradient(sign(x) - x)
def ste_quantize(x):
    return x + jax.lax.stop_gradient(jnp.sign(x) - x)

# Forward: x + (sign(x) - x) = sign(x)  ✓
print(ste_quantize(jnp.array([-1.5, 0.0, 0.5, 2.0])))
# → [-1.  0.  1.  1.]

# Backward: stop_gradient zeros the inner derivative → grad = 1
grad_ste = jax.grad(lambda x: ste_quantize(x).sum())
print(grad_ste(jnp.array([0.5])))    # [1.]  ← useful
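The difference matters as soon as the quantized value feeds a loss. In this sketch, the target, the starting weights, and the learning rate are made-up values. One plain gradient-descent step leaves the weights untouched when the forward pass uses jnp.sign. When it uses ste_quantize, the same step moves the weights to the correct signs.

```python
import jax
import jax.numpy as jnp

def ste_quantize(x):
    return x + jax.lax.stop_gradient(jnp.sign(x) - x)

target = jnp.array([1.0, -1.0])  # made-up binary targets
w0 = jnp.array([-0.3, 0.2])      # made-up latent weights with the wrong signs
lr = 0.1                         # made-up learning rate

def loss(w, quantize):
    # Squared error between the quantized weights and the targets
    return jnp.sum((quantize(w) - target) ** 2)

# One gradient-descent step with each forward op
w_naive = w0 - lr * jax.grad(lambda w: loss(w, jnp.sign))(w0)    # gradient is 0
w_ste = w0 - lr * jax.grad(lambda w: loss(w, ste_quantize))(w0)  # gradient is 2 * (sign(w) - target)

print(w_naive)  # unchanged: same as w0
print(w_ste)    # roughly [0.1, -0.2]: both signs now match target
```

The STE hands back the gradient with respect to the quantized value, 2 * (sign(w) - target), and applies it to the real-valued w unchanged. This is how the latent weights in binary-weight training keep learning.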

Common pitfalls

  • Writing jnp.sign(x) directly: its gradient is 0 almost everywhere, so nothing upstream learns. The surrogate trick is necessary for learning.
  • The algebraic trick x + stop_gradient(g(x) - x):
    • Forward: x + (g(x) - x) = g(x), the discrete output.
    • Backward: stop_gradient zeros d/dx [g(x) - x], leaving d/dx [x] = 1, the identity pass-through.
  • custom_vjp is not required for STE. The algebraic trick is shorter and composes with the rest of JAX. That includes forward-mode autodiff (jax.jvp, jax.jacfwd), which a custom_vjp function does not support.
  • jnp.sign(0) == 0: the forward value at zero is 0, not ±1. This is the standard convention for sign. If you need strictly binary outputs, decide explicitly which side zero maps to.
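For comparison, here is a sketch of the same estimator written with jax.custom_vjp, alongside the algebraic version. Both give the identity gradient in reverse mode. Only the algebraic version can also go through jax.jvp. The ste_binarize variant is a hypothetical example that maps 0 to +1 to get strictly ±1 outputs. Which side zero goes to is a design choice, not a fixed rule.

```python
import jax
import jax.numpy as jnp

def ste_quantize(x):
    # Algebraic STE, as in the worked example
    return x + jax.lax.stop_gradient(jnp.sign(x) - x)

@jax.custom_vjp
def ste_quantize_vjp(x):
    # Same estimator via custom_vjp (reverse mode only)
    return jnp.sign(x)

def _fwd(x):
    return jnp.sign(x), None   # no residuals needed for the backward pass

def _bwd(_residuals, g):
    return (g,)                # identity: pass the incoming cotangent through

ste_quantize_vjp.defvjp(_fwd, _bwd)

def ste_binarize(x):
    # Hypothetical variant: strictly +/-1 forward values, zero mapped to +1
    b = jnp.where(x >= 0, 1.0, -1.0).astype(x.dtype)
    return x + jax.lax.stop_gradient(b - x)
```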

Problem

Implement ste_quantize(x) that returns jnp.sign(x) in the forward pass and passes gradient through unchanged (identity) in the backward pass.

  • x: 1-D jax array.

Returns: a 1-D array of the same shape, equal to sign(x). Gradient (under autodiff): identity. The Jacobian is the identity matrix, so the gradient of sum(ste_quantize(x)) is 1 everywhere.
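One way to test a candidate against this spec is sketched below. check_ste is a hypothetical helper, not part of any grader. It compares the forward output with jnp.sign using a tolerance, because x + (sign(x) - x) can be off by one rounding step in floating point. It then checks that the full reverse-mode Jacobian from jax.jacrev is the identity matrix, which is what "gradient: identity" means for a vector-valued function.

```python
import jax
import jax.numpy as jnp

def check_ste(fn, x):
    """Hypothetical checker: forward matches sign(x), Jacobian is the identity."""
    y = fn(x)
    forward_ok = y.shape == x.shape and bool(jnp.allclose(y, jnp.sign(x)))
    J = jax.jacrev(fn)(x)                       # (n, n) Jacobian via reverse mode
    backward_ok = bool(jnp.allclose(J, jnp.eye(x.shape[0])))
    return forward_ok and backward_ok

def ste_quantize(x):
    return x + jax.lax.stop_gradient(jnp.sign(x) - x)

x = jnp.array([-1.5, 0.0, 0.5, 2.0])
print(check_ste(ste_quantize, x))  # True
print(check_ste(jnp.sign, x))      # False: the Jacobian of sign is all zeros
```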
