Dirichlet Sampling
Why this matters
The Dirichlet distribution is the conjugate prior for the categorical distribution and lives on the probability simplex: every sample is a probability vector whose components are non-negative and sum to 1. It is the standard prior in topic models (LDA uses it for the document-topic and topic-word distributions), in Bayesian models of categorical data, and in any scenario where you need a random distribution over K outcomes.
The concentration vector alpha controls the shape. A symmetric alpha with all
entries equal to 1 gives a uniform distribution over the simplex; a symmetric
alpha with large entries (much greater than 1) concentrates samples near the
center (similar proportions); small entries (< 1) concentrate samples near the
corners (sparse, near-one-hot).
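To make the effect of alpha concrete, the following is a minimal sketch that compares a small and a large symmetric alpha. The specific values (0.1 and 100.0), the sample count, and the variable names are illustrative assumptions, not part of the exercise.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key_sparse, key_dense = jax.random.split(key)

# Symmetric alpha well below 1: mass tends to pile onto a single component.
sparse = jax.random.dirichlet(key_sparse, jnp.full((3,), 0.1), shape=(1000,))
# Symmetric alpha well above 1: components tend to stay close to 1/3 each.
dense = jax.random.dirichlet(key_dense, jnp.full((3,), 100.0), shape=(1000,))

# The largest component of each row summarizes how "peaked" a sample is.
sparse_peak = sparse.max(axis=1).mean()  # expected to be well above 1/3
dense_peak = dense.max(axis=1).mean()    # expected to be only slightly above 1/3
```

Comparing sparse_peak with dense_peak is one quick way to check which regime a given alpha falls into.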
Worked mini-example
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)
key_uniform, key_conc = jax.random.split(key)  # independent keys, one per draw
alpha = jnp.array([1.0, 1.0, 1.0]) # uniform over the simplex (K = 3)
# Draw 5 samples; each row sums to 1
samples = jax.random.dirichlet(key_uniform, alpha, shape=(5,))
# -> float32 array of shape (5, 3); each row is a valid probability vector
# With concentrated alpha: most of the mass lands on class 0
alpha_conc = jnp.array([50.0, 1.0, 1.0])
samples_conc = jax.random.dirichlet(key_conc, alpha_conc, shape=(3,))
# -> index 0 holds about 96% on average (mean = 50 / 52)
Common pitfalls
- alpha must be all positive: zero or negative concentration parameters are invalid for the Dirichlet distribution.
- Shape is the SAMPLE shape, not the alpha shape: shape=(int(n),) gives n independent draws; the output shape is (n, K) where K = len(alpha).
- Each row sums to 1 (not each column, not the whole matrix).
- n may arrive as a float: cast to int before use.
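The pitfalls above can be sketched in a few lines. The alpha values, the float-valued n, and the explicit ValueError check are assumptions made for illustration; the exercise itself does not prescribe how invalid input should be handled.

```python
import jax
import jax.numpy as jnp

alpha = jnp.array([2.0, 0.5, 1.5])
n = 4.0  # sample count arriving as a float

# Reject non-positive concentration parameters up front (hypothetical policy).
if not bool(jnp.all(alpha > 0)):
    raise ValueError("alpha must be strictly positive")

# shape is the sample shape and must hold Python ints, hence int(n).
draws = jax.random.dirichlet(jax.random.PRNGKey(0), alpha, shape=(int(n),))

row_sums = draws.sum(axis=1)  # one sum per sample; each is approximately 1
col_sums = draws.sum(axis=0)  # per-component totals; generally not 1
```

Here draws has shape (4, 3): four samples, each a probability vector over the three components of alpha.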
Problem
Implement dirichlet_sample(seed, alpha, n) that draws n samples from
Dirichlet(alpha).
alpha is a 1-D array of shape (K,) with all positive entries. Return a
2-D float32 array of shape (n, K) where each row sums to 1.
One illustrative example (not from the test set):
- dirichlet_sample(0, jnp.array([1., 1., 1.]), 2.0) returns a float32 array of shape (2, 3) where each row sums to 1, deterministic for seed 0.
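One possible shape for such a function is sketched below. It is not the site's reference solution: mapping the seed to a key with jax.random.PRNGKey and the explicit float32 casts are assumptions, so the exact values it produces may differ from what the hidden tests expect.

```python
import jax
import jax.numpy as jnp

def dirichlet_sample(seed, alpha, n):
    """Draw n samples from Dirichlet(alpha) as an (n, K) float32 array."""
    key = jax.random.PRNGKey(int(seed))            # assumed seed-to-key mapping
    alpha = jnp.asarray(alpha, dtype=jnp.float32)  # accept lists or arrays
    # int(n) guards against n arriving as a float such as 2.0.
    return jax.random.dirichlet(key, alpha, shape=(int(n),), dtype=jnp.float32)

out = dirichlet_sample(0, jnp.array([1., 1., 1.]), 2.0)
again = dirichlet_sample(0, jnp.array([1., 1., 1.]), 2.0)  # same seed, same draws
```

Calling the function twice with the same seed returns identical arrays, which is what makes the output deterministic.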