Categorical Sampling
Why this matters
jax.random.categorical(key, logits, axis=-1, shape=None) samples discrete
indices from a categorical distribution whose probabilities are the softmax of
the logits. It operates on LOGITS (unnormalized log-probabilities), not raw
probabilities. This makes it the natural primitive for ancestral sampling from a
language model's output, for token selection during inference, and for any other
code that draws classes from unnormalized log-probabilities. The draw itself is
not differentiable: it returns integer indices, so no gradient flows through it.
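A quick way to confirm that categorical treats its input as logits is to draw many samples and compare the empirical class frequencies with jax.nn.softmax(logits). The sketch below assumes that 200,000 draws are enough for the frequencies to land within about 0.01 of the softmax probabilities. The helper name empirical_frequencies is hypothetical.

```python
import jax
import jax.numpy as jnp

def empirical_frequencies(key, logits, n_draws):
    """Hypothetical helper: fraction of n_draws samples that land in each class."""
    samples = jax.random.categorical(key, logits, shape=(n_draws,))
    # A static `length` keeps the output shape fixed at K, even if a class never appears.
    counts = jnp.bincount(samples, length=logits.shape[-1])
    return counts / n_draws

logits = jnp.array([1.0, 2.0, 3.0])
freqs = empirical_frequencies(jax.random.PRNGKey(0), logits, 200_000)
target = jax.nn.softmax(logits)  # approx [0.090, 0.245, 0.665]
print(freqs)
print(target)
```

The two printed vectors should agree to about two decimal places. Passing the same key again reproduces exactly the same frequencies.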
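The shape argument is the SAMPLE shape, i.e. the shape of the returned index
array (how many independent draws you want), not the shape of the logits. When
it is given, it must be broadcast-compatible with the logits' shape with axis
removed. For a 1-D logits vector, that batch shape is empty, so any sample shape works.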
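For batched logits, the sample shape has to end with the batch shape (logits.shape with axis removed). The sketch below uses illustrative sizes B=4 and K=3 and shows three cases: one draw per row, five draws per row, and classes stored on axis 0 instead of the last axis.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)  # fresh key per call instead of reusing one
B, K = 4, 3  # illustrative batch size and number of classes

# One row of logits per batch element; classes live on the last axis (axis=-1).
logits = jnp.zeros((B, K))

# shape=None: one draw per row; result shape is logits.shape without the class axis.
one_per_row = jax.random.categorical(k1, logits)  # shape (4,)

# Five draws per row: the trailing part of `shape` must match the batch shape (B,).
five_per_row = jax.random.categorical(k2, logits, shape=(5, B))  # shape (5, 4)

# Classes on axis 0 instead: pass (K, B) logits and say so explicitly with axis=0.
five_per_col = jax.random.categorical(k3, logits.T, axis=0, shape=(5, B))  # shape (5, 4)
```

A shape such as (5,) with these (4, 3) logits is rejected, because (5,) does not broadcast against the batch shape (4,).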
Worked mini-example
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([1.0, 2.0, 3.0])  # 3 classes; probabilities = softmax(logits)
# Draw 5 samples (one axis, 3 classes → 5 indices in {0, 1, 2})
samples = jax.random.categorical(key, logits, shape=(5,))
# → integer array of shape (5,); int32 unless 64-bit mode (jax_enable_x64) is on
# Cast to float for downstream arithmetic
samples_f = samples.astype(jnp.float32)
Common pitfalls
- Pass LOGITS, not probabilities: categorical applies softmax internally, so passing a probability vector samples from softmax(probs), which is the wrong distribution (a double softmax). If you only have probabilities, pass jnp.log(probs).
- shape is the SAMPLE shape: shape=(5,) means 5 independent draws, not a (5, K) logit matrix.
- Returns integers (int32 by default): cast to float if your downstream code expects floats.
- axis defaults to -1: for a 1-D logits vector this is correct; for batched logits matrices, be explicit about which axis holds the classes.
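To make the first pitfall concrete, the sketch below compares the distribution categorical would actually sample from in two cases: when given raw probabilities, and when given their logarithms. The probability vector [0.1, 0.2, 0.7] is illustrative.

```python
import jax
import jax.numpy as jnp

probs = jnp.array([0.1, 0.2, 0.7])  # illustrative target distribution

# Wrong: categorical applies softmax to its input, so it would sample from softmax(probs).
implied_wrong = jax.nn.softmax(probs)  # noticeably flatter than probs

# Right: convert to logits first; softmax(log p) == p when p sums to 1.
implied_right = jax.nn.softmax(jnp.log(probs))

key = jax.random.PRNGKey(0)
samples = jax.random.categorical(key, jnp.log(probs), shape=(5,))
samples_f = samples.astype(jnp.float32)  # integer indices -> float32 for downstream math
```

With raw probabilities, the most likely class drops from 0.7 to under 0.5, which is easy to miss without a check like this.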
Problem
Implement categorical_sample(seed, logits, n) that draws n independent
categorical samples from a distribution defined by logits.
n arrives as a float; cast it to int. Return a 1-D float32 array of shape
(n,) containing the sampled indices.
One illustrative example (not from the test set):
- categorical_sample(0, jnp.array([1.0, 1.0, 1.0]), 5.0) returns a 1-D float32 array of shape (5,) with indices in {0, 1, 2}, deterministic for seed 0.
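One way to meet this specification is sketched below. It assumes the integer seed is turned into a key with jax.random.PRNGKey and that a single call with shape=(n,) is acceptable. The hidden tests may expect a particular key derivation, which this page does not show.

```python
import jax
import jax.numpy as jnp

def categorical_sample(seed, logits, n):
    """Draw n categorical samples from `logits`; return float32 indices of shape (n,)."""
    n = int(n)                               # n arrives as a float, e.g. 5.0
    key = jax.random.PRNGKey(int(seed))      # assumption: seed -> key via PRNGKey
    logits = jnp.asarray(logits, dtype=jnp.float32)
    # shape=(n,) is the SAMPLE shape: n independent draws over the last axis.
    samples = jax.random.categorical(key, logits, axis=-1, shape=(n,))
    return samples.astype(jnp.float32)       # spec asks for float32, not int32
```

The int(n) conversion must happen in Python because shape has to be static. As written, the function therefore should not be traced under jax.jit with n as a traced argument.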