hard primitives

Gumbel-Softmax

Why this matters

Categorical sampling is not differentiable, because the argmax operation has zero gradient almost everywhere. The Gumbel-Softmax (also called the Concrete distribution) is a differentiable approximation: instead of sampling a one-hot vector, it produces a soft probability vector that concentrates near one-hot as temperature → 0. This enables gradient flow through discrete choices, which is useful for training discrete VAEs, VQ-VAE codebooks, and differentiable RL action selection.

The trick exploits the Gumbel-max theorem: adding independent Gumbel(0, 1) noise to the logits and taking the argmax is equivalent to sampling from the categorical distribution whose probabilities are softmax(logits). Replacing the argmax with a temperature-scaled softmax gives the differentiable relaxation.
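The Gumbel-max claim can be checked empirically. The sketch below is illustrative only: the sample size and the tolerance you would accept are arbitrary choices, not values from the original text. It draws many perturbed argmax samples and compares their class frequencies with softmax(logits).

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 3.0])
num_samples = 20_000  # illustrative sample size, chosen arbitrarily

key = jax.random.PRNGKey(0)
# One row of independent Gumbel(0, 1) noise per sample.
g = jax.random.gumbel(key, (num_samples, logits.shape[0]))

# Gumbel-max: the argmax over perturbed logits yields one class index per sample.
samples = jnp.argmax(logits + g, axis=-1)

# Empirical class frequencies versus the target categorical probabilities.
empirical = jnp.bincount(samples, length=logits.shape[0]) / num_samples
target = jax.nn.softmax(logits)

print("empirical:", empirical)
print("target:   ", target)
```

The two printed vectors should agree up to sampling noise, which shrinks as num_samples grows.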

Worked mini-example

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([1.0, 2.0, 3.0])
temperature = 0.5

g = jax.random.gumbel(key, logits.shape)   # Gumbel(0,1) noise
y = jax.nn.softmax((logits + g) / temperature)
# → soft approximation of a one-hot categorical sample; sums to 1
# At low temperature this is close to one-hot; at high temperature, near uniform
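To see why the relaxation matters, it may help to differentiate through it. The following is a minimal sketch, not part of the original example: relaxed_sample, expected_payoff, and the payoffs vector are hypothetical names and values introduced purely for illustration. The noise is drawn once and held fixed, so jax.grad differentiates only with respect to the logits.

```python
import jax
import jax.numpy as jnp

def relaxed_sample(logits, g, temperature):
    # Same computation as the mini-example, with the noise passed in explicitly.
    return jax.nn.softmax((logits + g) / temperature)

def expected_payoff(logits, g, temperature, payoffs):
    # Hypothetical downstream objective: payoffs weighted by the soft sample.
    return jnp.dot(relaxed_sample(logits, g, temperature), payoffs)

key = jax.random.PRNGKey(0)
logits = jnp.array([1.0, 2.0, 3.0])
payoffs = jnp.array([0.0, 1.0, 5.0])  # made-up values for illustration
g = jax.random.gumbel(key, logits.shape)

# Gradient with respect to the logits (argument 0), with the noise held fixed.
grad_logits = jax.grad(expected_payoff)(logits, g, 0.5, payoffs)
print(grad_logits)
```

The gradient entries sum to approximately zero, because adding a constant to every logit leaves the softmax output unchanged. With a hard argmax in place of the softmax, the gradient would be zero almost everywhere.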

Common pitfalls

  • Temperature controls sharpness: low temperature (e.g. 0.1) → nearly one-hot output; high temperature (e.g. 10.0) → nearly uniform output. Never pass temperature=0 (division by zero).
  • jax.random.gumbel(key, shape) draws Gumbel(0, 1) noise directly, so there is no need to derive it from uniform samples manually.
  • Forgetting the temperature divisor: the argument to softmax is (logits + g) / temperature, and the / temperature is essential.
  • The output is a probability vector that sums to 1, not the raw pre-softmax values.
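The temperature pitfall is easy to observe directly. This sketch reuses a single noise draw across three temperatures (the values 0.1, 1.0, and 10.0 are illustrative) and records the largest entry of each output; the dictionary name largest is an assumption made for this example.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([1.0, 2.0, 3.0])
g = jax.random.gumbel(key, logits.shape)  # one noise draw, reused for every temperature

largest = {}  # temperature -> largest entry of the soft sample
for temperature in (0.1, 1.0, 10.0):
    y = jax.nn.softmax((logits + g) / temperature)
    largest[temperature] = float(jnp.max(y))
    print(temperature, y)
```

For a fixed noise draw, the largest entry moves toward 1 as the temperature falls and toward 1/K (here 1/3) as it rises, while every output still sums to 1.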

Problem

Implement gumbel_softmax(seed, logits, temperature) that returns a differentiable soft categorical sample.

logits is a 1-D array of shape (K,). Return a 1-D probability array of shape (K,) summing to 1.

One illustrative example (not from the test set):

  • gumbel_softmax(0, jnp.array([0.0, 0.0]), 1.0) returns a 1-D float32 array of shape (2,) summing to 1, deterministic for seed 0.

Hints

jax gumbel-softmax categorical
