Gumbel Argmax (Categorical via Trick)
Why this matters
The Gumbel-Max trick shows that sampling from the categorical distribution
Categorical(softmax(logits)) is equivalent to taking
argmax(logits + g), where each entry of g is drawn i.i.d. from Gumbel(0, 1). This reformulation
has two major uses in modern ML:
- Gumbel-Softmax / Concrete distribution (Jang et al., 2017; Maddison et al., 2017): replace argmax with softmax to get a differentiable relaxation of discrete sampling, used for example in VAEs and RL.
- Efficient categorical sampling: the trick reduces a categorical draw to an elementwise add of noise followed by a single argmax, with no cumulative sum or search over the CDF, which parallelizes well on accelerators.
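A minimal sketch of the Gumbel-Softmax relaxation mentioned above, assuming the standard form softmax((logits + g) / tau) with a temperature tau > 0. The name `gumbel_softmax` and the default `tau=1.0` are illustrative, not a library API. Because softmax preserves ordering, its largest entry sits at the same index as argmax(logits + g), while the output remains differentiable with respect to `logits`.

```python
import jax
import jax.numpy as jnp

def gumbel_softmax(key, logits, tau=1.0):
    """Hypothetical helper: a relaxed (soft one-hot) categorical sample.

    tau is the temperature; smaller values push the output toward one-hot.
    """
    g = jax.random.gumbel(key, logits.shape)   # same noise as in Gumbel-Max
    return jax.nn.softmax((logits + g) / tau)  # softmax replaces argmax

key = jax.random.PRNGKey(0)
logits = jnp.array([1.0, 2.0, 3.0])

y = gumbel_softmax(key, logits, tau=0.5)  # probability vector summing to 1
hard = jnp.argmax(y)                      # index of the largest soft entry

# Gradients flow through the softmax back to the logits.
grad = jax.grad(lambda l: gumbel_softmax(key, l, tau=0.5)[2])(logits)
print(y, hard, grad)
```

Lowering `tau` sharpens `y` toward a one-hot vector but makes the gradients higher-variance; that trade-off is why the temperature is usually treated as a tuning knob.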
Understanding this identity is a prerequisite for implementing Gumbel-Softmax and for reasoning about stochastic computation graphs.
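A short derivation of the identity, writing x_i for logits[i] and using the Gumbel(0, 1) CDF F(g) = exp(-e^{-g}) and density f(g) = e^{-g} exp(-e^{-g}). Index i wins exactly when z_i = x_i + g_i exceeds every other z_j, so we integrate over the winning value z:

```latex
\begin{aligned}
P\Big(\arg\max_j (x_j + g_j) = i\Big)
  &= \int_{-\infty}^{\infty} f(z - x_i) \prod_{j \ne i} F(z - x_j)\, dz \\
  &= \int_{-\infty}^{\infty} e^{x_i} e^{-z} \exp\!\Big(-e^{-z} \sum_{j} e^{x_j}\Big)\, dz \\
  &= e^{x_i} \int_{0}^{\infty} \exp\!\Big(-t \sum_{j} e^{x_j}\Big)\, dt
     \qquad (t = e^{-z}) \\
  &= \frac{e^{x_i}}{\sum_{j} e^{x_j}} = \operatorname{softmax}(x)_i .
\end{aligned}
```

The second line merges the density term for i with the CDF terms for j ≠ i into a single sum over all j.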
Worked mini-example
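import jax
import jax.numpy as jnp

def gumbel_argmax(seed, logits):
    key = jax.random.PRNGKey(int(seed))       # PRNGKey needs an integer seed
    g = jax.random.gumbel(key, logits.shape)  # i.i.d. Gumbel(0, 1) noise, same shape as logits
    sample = jnp.argmax(logits + g)           # Gumbel-Max trick: index of the sampled category
    return sample.astype(jnp.float32)         # return the index as a float32 scalar

logits = jnp.array([1.0, 2.0, 3.0])
print(gumbel_argmax(0.0, logits))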
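To check the identity from "Why this matters" numerically, the sketch below draws many Gumbel-Max samples, one independent key per draw, and compares the empirical frequencies with softmax(logits). The sample count of 100,000 is an arbitrary choice; at that size the frequencies should match the softmax probabilities to within about a percentage point.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 3.0])
num_samples = 100_000  # arbitrary; larger means tighter agreement

# One independent PRNG key per sample.
keys = jax.random.split(jax.random.PRNGKey(0), num_samples)

def one_sample(key):
    g = jax.random.gumbel(key, logits.shape)
    return jnp.argmax(logits + g)  # integer index in [0, K)

samples = jax.vmap(one_sample)(keys)  # shape (num_samples,)
empirical = jnp.bincount(samples, length=logits.shape[0]) / num_samples
expected = jax.nn.softmax(logits)

print("empirical:", empirical)
print("softmax:  ", expected)
```

`jax.vmap` keeps the per-sample function identical to the scalar version, so the same code path is exercised as in `gumbel_argmax`.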
Common pitfalls
- Identical seeds → identical output: the Gumbel-Max trick is deterministic given the key. Different seeds draw different noise, so they can produce different samples for the same logits, though two seeds may still pick the same index.
- Cast to float32: jnp.argmax returns an integer index, but the test contract expects a float32 scalar.
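- Use jax.random.gumbel rather than a manual -log(-log(uniform)): both produce Gumbel(0, 1) noise, but a hand-rolled version must keep the uniform draw away from 0 to avoid an infinite log, whereas the built-in already handles this and is the idiomatic choice.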
- Shape: jax.random.gumbel(key, logits.shape) produces noise with the same shape as logits → match this exactly.
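For comparison with the pitfall above, a hand-rolled inverse-CDF version might look like this sketch; `manual_gumbel` is a hypothetical helper, not part of JAX. It clamps the uniform draw away from 0 via `minval`, the detail a naive version misses. The known mean of Gumbel(0, 1) (the Euler-Mascheroni constant, about 0.5772) and its variance (pi^2/6, about 1.645) give a quick sanity check against `jax.random.gumbel`.

```python
import jax
import jax.numpy as jnp

def manual_gumbel(key, shape):
    """Hypothetical helper: Gumbel(0, 1) noise via g = -log(-log(u))."""
    # Keep u >= tiny so log(u) is finite; uniform's upper bound 1.0 is exclusive,
    # so -log(u) stays strictly positive and the outer log is finite too.
    tiny = jnp.finfo(jnp.float32).tiny
    u = jax.random.uniform(key, shape, minval=tiny, maxval=1.0)
    return -jnp.log(-jnp.log(u))

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
n = 200_000
manual = manual_gumbel(key_a, (n,))
builtin = jax.random.gumbel(key_b, (n,))

# Both should have mean near 0.5772 and variance near 1.645.
print(manual.mean(), builtin.mean())
print(manual.var(), builtin.var())
```

Matching moments only shows the two are equal in distribution; with the same key they are not guaranteed to return the same numbers, so keep using `jax.random.gumbel` when outputs must match a reference.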
Problem
Implement gumbel_argmax(seed, logits) that returns the index of the
Gumbel-Max sample as a float32 scalar.
seed is a float (cast to int to make a PRNGKey). logits is a 1-D
float array of shape (K,). Return a float32 scalar (the argmax index).
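A small self-check for the contract stated above might look like the following. `check_contract` and its default logits are hypothetical and are not the grader's actual tests. It runs a candidate implementation against the properties the problem and pitfalls describe: float32 dtype, scalar shape, a valid integer index, and determinism for a fixed seed.

```python
import jax.numpy as jnp

def check_contract(fn, seed=3.0, logits=None):
    """Hypothetical checker for the gumbel_argmax contract.

    fn: candidate taking (seed: float, logits: 1-D float array of shape (K,)).
    Raises AssertionError on the first violated property; returns True otherwise.
    """
    if logits is None:
        logits = jnp.array([0.5, -1.0, 2.0, 0.0])  # arbitrary example logits
    out = fn(seed, logits)
    assert out.dtype == jnp.float32, "must return float32"
    assert out.shape == (), "must return a scalar"
    idx = float(out)
    assert idx == int(idx) and 0 <= idx < logits.shape[0], "must be a valid index"
    # Same seed -> same key -> same noise -> same index.
    assert float(fn(seed, logits)) == idx, "must be deterministic for a fixed seed"
    return True
```

For example, `check_contract(gumbel_argmax)` should return True for a correct implementation, while a version that forgets the float32 cast fails the first assertion.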