
Gumbel Argmax (Categorical via Trick)

Why this matters

The Gumbel-Max trick shows that sampling from a categorical distribution Categorical(softmax(logits)) is equivalent in distribution to taking argmax(logits + g), where each component of g is drawn i.i.d. from Gumbel(0, 1). This reformulation has two major uses in modern ML:

  1. Gumbel-Softmax / Concrete distribution (Jang et al., 2017; Maddison et al., 2017): replace argmax with softmax to get a differentiable relaxation for discrete sampling in VAEs and RL.
  2. Efficient categorical sampling: the trick turns a categorical draw into an elementwise noise addition followed by an argmax, which vectorizes easily and is often fast on accelerators.
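To illustrate use 1, here is a minimal Gumbel-Softmax sketch in JAX. It assumes the plain (soft) relaxation softmax((logits + g) / τ) with a hand-picked temperature τ = 0.5; the function name gumbel_softmax is hypothetical, and the straight-through variant is not shown.

```python
import jax
import jax.numpy as jnp


def gumbel_softmax(key, logits, temperature=1.0):
    """Hypothetical helper: relaxed one-hot sample softmax((logits + g) / temperature)."""
    g = jax.random.gumbel(key, logits.shape)            # i.i.d. Gumbel(0, 1) noise
    return jax.nn.softmax((logits + g) / temperature)   # argmax replaced by softmax


key = jax.random.PRNGKey(0)
logits = jnp.array([1.0, 2.0, 3.0])

soft = gumbel_softmax(key, logits, temperature=0.5)
# Gumbel-Max with the same key, hence exactly the same noise as above.
hard = jnp.argmax(logits + jax.random.gumbel(key, logits.shape))

# Dividing by a positive temperature and applying softmax preserve order,
# so the relaxed sample's largest entry sits at the Gumbel-Max index.
print(soft, jnp.argmax(soft) == hard)

# Gradients flow through the relaxation (argmax itself has no useful gradient).
grad = jax.grad(lambda l: gumbel_softmax(key, l, 0.5)[2])(logits)
print(grad)
```

Lower temperatures push the relaxed sample toward a one-hot vector; higher temperatures flatten it toward uniform.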

Understanding this identity is a prerequisite for implementing Gumbel-Softmax and for reasoning about stochastic computation graphs.
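For reference, the identity follows from a short calculation. The notation here is introduced for the sketch: l denotes the logits, x_i = l_i + g_i the perturbed logits, and F, f the standard Gumbel CDF and density; ties occur with probability zero and are ignored.

```latex
\begin{aligned}
&x_i = l_i + g_i,\quad g_i \overset{\text{i.i.d.}}{\sim} \mathrm{Gumbel}(0,1),\quad
  F(g) = e^{-e^{-g}},\quad f(g) = e^{-g}\,e^{-e^{-g}} \\
&\Pr\big[\arg\max_i x_i = k\big]
  = \int_{-\infty}^{\infty} f(x - l_k) \prod_{j \neq k} F(x - l_j)\,dx \\
&\quad = \int_{-\infty}^{\infty} e^{l_k}\, e^{-x}
  \exp\!\Big(-e^{-x} \sum_{j} e^{l_j}\Big)\,dx \\
&\quad = e^{l_k} \int_{0}^{\infty} e^{-uZ}\,du
  \qquad \big(u = e^{-x},\ Z = \textstyle\sum_j e^{l_j}\big) \\
&\quad = \frac{e^{l_k}}{Z} = \mathrm{softmax}(l)_k
\end{aligned}
```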

Worked mini-example

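import jax, jax.numpy as jnp

def gumbel_argmax(seed, logits):
    key = jax.random.PRNGKey(int(seed))        # seed arrives as a float; PRNGKey needs an int
    g = jax.random.gumbel(key, logits.shape)   # sample Gumbel(0, 1) noise, same shape as logits
    sample = jnp.argmax(logits + g)            # Gumbel-Max trick: argmax of perturbed logits
    return sample.astype(jnp.float32)          # int32 index -> float32 scalar

seed = 0.0
logits = jnp.array([1.0, 2.0, 3.0])
print(gumbel_argmax(seed, logits))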
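A quick empirical check of the identity, as a sketch: draw many independent keys, apply the trick to each with jax.vmap, and compare the observed index frequencies with softmax(logits). The helper name gumbel_argmax_index and the sample size n are illustrative choices, not part of the problem.

```python
import jax
import jax.numpy as jnp


def gumbel_argmax_index(key, logits):
    # Hypothetical helper: Gumbel-Max for one key, returning the integer index.
    return jnp.argmax(logits + jax.random.gumbel(key, logits.shape))


logits = jnp.array([1.0, 2.0, 3.0])
n = 100_000

# One independent key per draw; vmap evaluates the trick for all keys at once.
keys = jax.random.split(jax.random.PRNGKey(0), n)
idx = jax.vmap(gumbel_argmax_index, in_axes=(0, None))(keys, logits)

empirical = jnp.bincount(idx, length=logits.shape[0]) / n
target = jax.nn.softmax(logits)  # ≈ [0.090, 0.245, 0.665]
print(empirical, target)
```

With 100,000 draws the empirical frequencies should land within about a percentage point of the softmax probabilities.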

Common pitfalls

  • Identical seeds → identical output: the Gumbel-Max trick is deterministic given the key. Different seeds generally produce different samples for the same logits, though they can coincide, especially when one logit dominates.
  • Cast to float32: jnp.argmax returns an integer index, but the test contract expects a float scalar.
  • Prefer jax.random.gumbel over a manual -log(-log(uniform)): the two are equivalent in distribution, but the built-in takes care of the u = 0 endpoint (where the inner log would be infinite) and is the idiomatic choice.
  • Shape: jax.random.gumbel(key, logits.shape) produces noise of the same shape as logits; match this exactly.
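The pitfalls above can be checked directly. This sketch assumes a hypothetical hand-rolled sampler, manual_gumbel, that draws u from Uniform(tiny, 1) so the inner log stays finite; it compares sample means of both samplers against the Gumbel(0, 1) mean (the Euler-Mascheroni constant) rather than expecting bit-identical values.

```python
import jax
import jax.numpy as jnp

EULER_GAMMA = 0.5772156649  # mean of a standard Gumbel(0, 1) variable


def manual_gumbel(key, shape):
    # Hypothetical hand-rolled sampler: g = -log(-log(u)), u ~ Uniform(tiny, 1).
    # minval = smallest positive float32 keeps u > 0, so log(u) is finite.
    u = jax.random.uniform(key, shape, minval=jnp.finfo(jnp.float32).tiny, maxval=1.0)
    return -jnp.log(-jnp.log(u))


logits = jnp.array([1.0, 2.0, 3.0])
key = jax.random.PRNGKey(7)

# Same key -> same noise -> same argmax.
g1 = jax.random.gumbel(key, logits.shape)
g2 = jax.random.gumbel(key, logits.shape)
print(bool(jnp.array_equal(g1, g2)))  # True

# Noise shape equals logits.shape, so the addition is purely elementwise.
print((logits + g1).shape)  # (3,)

# Built-in and hand-rolled noise agree in distribution: compare sample means.
k_builtin, k_manual = jax.random.split(jax.random.PRNGKey(1))
m_builtin = jax.random.gumbel(k_builtin, (100_000,)).mean()
m_manual = manual_gumbel(k_manual, (100_000,)).mean()
print(m_builtin, m_manual, EULER_GAMMA)
```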

Problem

Implement gumbel_argmax(seed, logits) that returns the index of the Gumbel-Max sample as a float32 scalar.

seed is a float (cast to int to make a PRNGKey). logits is a 1-D float array of shape (K,). Return a float32 scalar (the argmax index).
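One way to sanity-check an implementation against this contract is sketched below. The checker satisfies_contract and the example logits are hypothetical; they only encode the stated requirements (float32, 0-d, an integral index in [0, K)).

```python
import jax
import jax.numpy as jnp


def gumbel_argmax(seed, logits):
    key = jax.random.PRNGKey(int(seed))                 # float seed -> int for PRNGKey
    g = jax.random.gumbel(key, logits.shape)            # Gumbel(0, 1) noise, shape (K,)
    return jnp.argmax(logits + g).astype(jnp.float32)   # index as a float32 scalar


def satisfies_contract(out, num_classes):
    # Hypothetical checker: float32 dtype, scalar shape, integral index in [0, K).
    value = float(out)
    return (out.dtype == jnp.float32 and out.shape == ()
            and value.is_integer() and 0 <= value < num_classes)


logits = jnp.array([0.5, -1.0, 2.0, 0.0])
results = [gumbel_argmax(s, logits) for s in (0.0, 1.0, 2.0, 3.0)]
print([float(r) for r in results], all(satisfies_contract(r, 4) for r in results))
```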

Hints

jax gumbel argmax categorical
