
Categorical Sampling

Why this matters

jax.random.categorical(key, logits, axis=-1, shape=None) samples discrete indices from a categorical distribution whose class probabilities are the softmax of the logits along axis. It operates on LOGITS (unnormalized log-probabilities), not raw probabilities. This makes it the natural primitive for ancestral sampling from a language model's output and for token selection during inference. Note that the result is a set of integer indices, so the draw itself is not differentiable with respect to the logits.
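A quick way to convince yourself that the logits define probabilities via softmax is to compare empirical frequencies against jax.nn.softmax. This is a minimal sketch; the sample count of 100,000 is an arbitrary choice large enough to keep sampling noise small.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([1.0, 2.0, 3.0])

# Class probabilities implied by the logits: softmax(logits).
probs = jax.nn.softmax(logits)

# shape=(100_000,) requests 100k independent draws from the same distribution.
draws = jax.random.categorical(key, logits, shape=(100_000,))

# Empirical frequency of each class index 0, 1, 2.
freqs = jnp.bincount(draws, length=logits.shape[0]) / draws.shape[0]

print(probs)  # approx [0.090, 0.245, 0.665]
print(freqs)  # close to probs, up to sampling noise
```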

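The shape argument is the SAMPLE (output) shape, i.e. how many independent draws you want, not the shape of the logits. It must be broadcast-compatible with the batch shape of the logits, which is logits.shape with axis removed. For a 1-D logits vector the batch shape is (), so any sample shape works.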
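For batched logits the broadcast rule matters. The sketch below assumes a small (2, 3) logits matrix, meaning two rows, each a distribution over 3 classes on the last axis. The trailing dimension of shape must line up with the batch shape (2,).

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key_a, key_b = jax.random.split(key)  # fresh keys for independent calls

# Batched logits: 2 distributions, each over 3 classes (classes on axis -1).
logits = jnp.array([[0.0, 0.0, 10.0],
                    [10.0, 0.0, 0.0]])   # shape (2, 3)

# shape=None: one draw per row, so the output has the batch shape (2,).
one_each = jax.random.categorical(key_a, logits, axis=-1)
print(one_each.shape)   # (2,)

# shape=(4, 2): 4 draws for each of the 2 rows. The trailing 2 must be
# broadcast-compatible with the batch shape (2,).
four_each = jax.random.categorical(key_b, logits, axis=-1, shape=(4, 2))
print(four_each.shape)  # (4, 2)
# Row 0 almost always yields class 2 and row 1 class 0, given these logits.
```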

Worked mini-example

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([1.0, 2.0, 3.0])   # 3-class pmf via softmax

# Draw 5 samples (one axis, 3 classes → 5 indices in {0,1,2})
samples = jax.random.categorical(key, logits, shape=(5,))
# → int32 array of shape (5,)

# Cast to float for downstream arithmetic
samples_f = samples.astype(jnp.float32)

Common pitfalls

  • Pass LOGITS, not probabilities: categorical applies softmax internally. Passing a probability vector p samples from softmax(p) instead of p; convert with jnp.log(p) first.
  • shape is the SAMPLE shape: shape=(5,) means 5 independent draws, not a (5, K) logit matrix.
  • Returns integer indices (int32 unless 64-bit mode is enabled): cast to float if your downstream code expects floats.
  • axis defaults to -1: correct for a 1-D logits vector; for batched logits, set axis explicitly to the dimension that indexes the classes.
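The first pitfall is easy to check numerically. This sketch uses an illustrative probability vector [0.1, 0.2, 0.7] and a helper named empirical (hypothetical, defined here only for the demonstration) to compare passing jnp.log(probs) against passing probs directly.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
probs = jnp.array([0.1, 0.2, 0.7])   # an already-normalized probability vector

def empirical(logits, key, n=100_000):
    """Hypothetical helper: empirical class frequencies from n draws."""
    draws = jax.random.categorical(key, logits, shape=(n,))
    return jnp.bincount(draws, length=logits.shape[-1]) / n

# Correct: log turns probabilities into logits, and softmax(log p) == p.
good = empirical(jnp.log(probs), key)

# Wrong: probabilities treated as logits, so draws follow softmax(probs).
bad = empirical(probs, key)

print(good)                   # close to [0.1, 0.2, 0.7]
print(jax.nn.softmax(probs))  # approx [0.25, 0.28, 0.46], which `bad` tracks
```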

Problem

Implement categorical_sample(seed, logits, n) that draws n independent categorical samples from a distribution defined by logits.

n arrives as a float; cast it to int. Return a 1-D float32 array of shape (n,) containing the sampled indices.

One illustrative example (not from the test set):

  • categorical_sample(0, jnp.array([1.0, 1.0, 1.0]), 5.0) returns a 1-D float32 array of shape (5,) with indices in {0, 1, 2}, deterministic for seed 0.
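One possible sketch of categorical_sample, assuming the integer seed is turned into a key with jax.random.PRNGKey as in the worked example (a different key constructor could change the exact indices drawn) and that logits is 1-D:

```python
import jax
import jax.numpy as jnp

def categorical_sample(seed, logits, n):
    """Draw int(n) categorical samples from `logits`, returned as float32.

    Sketch only: assumes `seed` is an integer and `logits` is a 1-D array
    of unnormalized log-probabilities.
    """
    key = jax.random.PRNGKey(int(seed))
    n = int(n)                                    # n arrives as a float
    # shape=(n,) is the sample shape: n independent draws over axis -1.
    idx = jax.random.categorical(key, logits, axis=-1, shape=(n,))
    return idx.astype(jnp.float32)                # integer indices -> float32

out = categorical_sample(0, jnp.array([1.0, 1.0, 1.0]), 5.0)
print(out.shape, out.dtype)  # (5,) float32
```

Because the key is derived only from seed, repeated calls with the same arguments return identical arrays.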

Hints

jax random categorical
