medium primitives

Dirichlet Sampling

Why this matters

The Dirichlet distribution is the conjugate prior for the categorical distribution and lives on the probability simplex: every sample is a probability vector whose components are non-negative and sum to 1. It is the go-to prior for topic models (LDA uses it to model document-topic and topic-word distributions), Bayesian models of categorical data, and any scenario where you need a random distribution over K outcomes.

The concentration vector alpha controls the shape: symmetric alpha with all entries equal to 1 gives a uniform distribution over the simplex; large alpha concentrates samples near the center (similar proportions); small alpha (< 1) concentrates samples near the corners (sparse, near-one-hot).
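One way to see the effect of alpha is to measure how much mass the largest component of each sample holds. The sketch below is illustrative only: the helper name largest_share, the sample count, and the specific alpha values (0.1, 1.0, 100.0) are assumptions chosen for this example, not part of the exercise.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Use a separate subkey for every independent draw.
k_small, k_one, k_large = jax.random.split(key, 3)

K = 3      # number of outcomes (illustrative)
n = 1000   # samples per setting (illustrative)

def largest_share(subkey, a):
    """Mean of the largest component under a symmetric Dirichlet(a, ..., a)."""
    samples = jax.random.dirichlet(subkey, jnp.full((K,), a), shape=(n,))
    return samples.max(axis=-1).mean()

sparse = largest_share(k_small, 0.1)      # small alpha: samples near the corners
flat = largest_share(k_one, 1.0)          # alpha = 1: uniform over the simplex
centered = largest_share(k_large, 100.0)  # large alpha: samples near the center

print(float(sparse), float(flat), float(centered))
```

The three printed values should decrease from left to right: near-one-hot samples put almost all mass on a single component, while samples near the center put roughly 1/K on each.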

Worked mini-example

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key_uniform, key_conc = jax.random.split(key)  # one subkey per independent draw
alpha = jnp.array([1.0, 1.0, 1.0])   # uniform over the simplex (K = 3)

# Draw 5 samples; each row sums to 1
samples = jax.random.dirichlet(key_uniform, alpha, shape=(5,))
# → float32 array of shape (5, 3); each row is a valid probability vector

# With concentrated alpha: mostly mass on class 0
alpha_conc = jnp.array([50.0, 1.0, 1.0])
samples_conc = jax.random.dirichlet(key_conc, alpha_conc, shape=(3,))
# → rows average about 96% (50/52) of their mass on index 0

Common pitfalls

  • alpha must be all positive: zero or negative concentration parameters are invalid for the Dirichlet distribution.
  • Shape is the SAMPLE shape, not the alpha shape: shape=(int(n),) gives n independent draws; the output shape is (n, K) where K = len(alpha).
  • Each row sums to 1 (not each column, not the whole matrix).
  • n may arrive as a float: cast to int before use.
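The pitfalls above can be checked directly on a small draw. This is a minimal sketch; the alpha values and n = 4.0 are made up for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
alpha = jnp.array([2.0, 3.0, 5.0])  # K = 3, illustrative values
n = 4.0                             # n may arrive as a float

# Pitfall: every concentration parameter must be strictly positive.
assert bool(jnp.all(alpha > 0))

# Pitfall: shape is the sample shape and must contain ints, so cast n first.
samples = jax.random.dirichlet(key, alpha, shape=(int(n),))
print(samples.shape)  # (n, K), here (4, 3)

# Pitfall: normalization holds along the last axis (rows), not columns.
row_sums = samples.sum(axis=-1)  # one value per sample, each close to 1
col_sums = samples.sum(axis=0)   # one value per outcome, generally not 1
print(row_sums, col_sums)
```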

Problem

Implement dirichlet_sample(seed, alpha, n) that draws n samples from Dirichlet(alpha).

alpha is a 1-D array of shape (K,) with all positive entries. Return a 2-D float32 array of shape (n, K) where each row sums to 1.

One illustrative example (not from the test set):

  • dirichlet_sample(0, jnp.array([1., 1., 1.]), 2.0) returns a float32 array of shape (2, 3) where each row sums to 1, deterministic for seed 0.
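One possible shape for such a function is sketched below. It is not the site's reference solution: it assumes the seed is turned into a key with jax.random.PRNGKey(seed) and used directly, so the exact values may differ from what a grader expects if the key is derived some other way.

```python
import jax
import jax.numpy as jnp

def dirichlet_sample(seed, alpha, n):
    """Draw n samples from Dirichlet(alpha); returns a float32 array of shape (n, K)."""
    # Assumption: the integer seed maps straight to a PRNG key.
    key = jax.random.PRNGKey(int(seed))
    alpha = jnp.asarray(alpha, dtype=jnp.float32)
    # n may arrive as a float, so cast before building the sample shape.
    samples = jax.random.dirichlet(key, alpha, shape=(int(n),))
    return samples.astype(jnp.float32)

out = dirichlet_sample(0, jnp.array([1.0, 1.0, 1.0]), 2.0)
print(out.shape, out.dtype)  # (2, 3) float32
print(out.sum(axis=-1))      # each entry close to 1
```

Because the key depends only on the seed, calling the function twice with the same arguments returns the same array.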

Hints

jax.random.dirichlet
