Gumbel-Softmax
Why this matters
Categorical sampling is not differentiable: the argmax operation has zero gradient almost everywhere. The Gumbel-Softmax (also called the Concrete distribution) is a differentiable approximation: instead of sampling a one-hot vector, it produces a soft probability vector that concentrates near one-hot as temperature → 0. This enables gradient flow through discrete choices, which is essential for training discrete VAEs, VQ-VAE codebooks, and differentiable RL action selection.
The trick exploits the Gumbel-max theorem: adding independent Gumbel(0, 1) noise to each logit and taking the argmax is equivalent to sampling from the categorical distribution with probabilities softmax(logits). Replacing the argmax with a temperature-scaled softmax gives the differentiable relaxation.
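The Gumbel-max theorem can be checked empirically. The following is a minimal sketch, assuming a sample count of 20,000 and the same three logits used elsewhere on this page (both are illustrative choices, not part of the problem): it draws many argmax samples from noisy logits and compares the class frequencies with softmax(logits).

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 3.0])
key = jax.random.PRNGKey(0)
num_samples = 20_000  # illustrative choice; larger values tighten the match

# One row of independent Gumbel(0, 1) noise per sample.
g = jax.random.gumbel(key, (num_samples, logits.shape[0]))

# Gumbel-max: the argmax of the perturbed logits is a categorical sample.
samples = jnp.argmax(logits + g, axis=-1)

# Empirical class frequencies versus the target probabilities.
empirical = jnp.bincount(samples, length=logits.shape[0]) / num_samples
target = jax.nn.softmax(logits)

print("empirical:", empirical)
print("target:   ", target)
```

The two vectors should agree up to sampling noise, which shrinks as the number of samples grows.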
Worked mini-example
import jax, jax.numpy as jnp
key = jax.random.PRNGKey(0)
logits = jnp.array([1.0, 2.0, 3.0])
temperature = 0.5
g = jax.random.gumbel(key, logits.shape) # Gumbel(0,1) noise
y = jax.nn.softmax((logits + g) / temperature)
# → soft approximation of a one-hot categorical sample; sums to 1
# At low temperature this is close to one-hot; at high temperature, near uniform
Common pitfalls
- Temperature controls sharpness: low temperature (e.g. 0.1) → nearly one-hot output; high temperature (e.g. 10.0) → nearly uniform output. Never pass temperature=0 (division by zero).
- jax.random.gumbel(key, shape) draws Gumbel(0, 1) directly; there is no need to derive it from uniform samples manually.
- Forgetting the temperature divisor: in (logits + g) / temperature, the division by temperature is essential.
- The output sums to 1 (it is a probability vector); it is not the raw pre-softmax values.
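To make the temperature pitfall concrete, here is a small sketch that reuses one noise draw across three temperatures, so that only the divisor changes. The temperature values are the ones mentioned above plus 1.0; the dictionary names are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([1.0, 2.0, 3.0])

# Draw the noise once so that only the temperature differs between runs.
g = jax.random.gumbel(key, logits.shape)

temperatures = [0.1, 1.0, 10.0]
samples = {t: jax.nn.softmax((logits + g) / t) for t in temperatures}

# Largest entry of each soft sample: close to 1 means nearly one-hot,
# close to 1/K (here 1/3) means nearly uniform.
peaks = {t: float(jnp.max(y)) for t, y in samples.items()}

for t in temperatures:
    print(f"temperature={t}: y={samples[t]}, max={peaks[t]:.3f}")
```

For a fixed noise draw, the largest entry can only shrink as the temperature rises, while the winning index stays the same; every output still sums to 1.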
Problem
Implement gumbel_softmax(seed, logits, temperature) that returns a
differentiable soft categorical sample.
logits is a 1-D array of shape (K,). Return a 1-D probability array of
shape (K,) summing to 1.
One illustrative example (not from the test set):
- gumbel_softmax(0, jnp.array([0.0, 0.0]), 1.0) returns a 1-D float32 array of shape (2,) summing to 1, deterministic for seed 0.
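One possible shape for such a function is sketched below. It is not the reference solution: turning the integer seed into a key with jax.random.PRNGKey(seed) is an assumption, since the problem statement does not say how the seed is consumed. The sketch also checks that gradients with respect to the logits are non-zero, which is the point of the relaxation.

```python
import jax
import jax.numpy as jnp


def gumbel_softmax(seed, logits, temperature):
    """Soft categorical sample of shape (K,) from logits of shape (K,)."""
    # Assumption: the integer seed is converted to a key with PRNGKey.
    key = jax.random.PRNGKey(seed)
    g = jax.random.gumbel(key, logits.shape)  # Gumbel(0, 1) noise
    return jax.nn.softmax((logits + g) / temperature)


logits = jnp.array([0.0, 0.0])
y = gumbel_softmax(0, logits, 1.0)
print(y.shape, y.dtype, float(y.sum()))

# Gradient of the first output entry with respect to the logits.
grad = jax.grad(lambda l: gumbel_softmax(0, l, 1.0)[0])(logits)
print(grad)
```

Because the key is derived from the seed alone, repeated calls with the same arguments return the same array.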