
Beta Distribution Sampling

Why this matters

Beta(α, β) is the conjugate prior for the success probability of a Bernoulli or binomial likelihood, which makes it the workhorse of Bayesian binary classification and Thompson sampling. Every sample is a real number in (0, 1), so it is a natural model for probabilities, mixing coefficients, and belief states.
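Conjugacy means the posterior stays in the Beta family: starting from a Beta(α, β) prior and observing s successes and f failures gives the posterior Beta(α + s, β + f). The sketch below assumes the observations arrive as a list of 0/1 outcomes; `beta_posterior` is a hypothetical helper for illustration, not a JAX API.

```python
import jax.numpy as jnp

def beta_posterior(alpha, beta, observations):
    """Hypothetical helper: posterior Beta parameters after 0/1 observations."""
    obs = jnp.asarray(observations)
    successes = jnp.sum(obs)          # number of 1s
    failures = obs.size - successes   # number of 0s
    return alpha + successes, beta + failures

# Uniform prior Beta(1, 1), then 3 successes and 1 failure
a_post, b_post = beta_posterior(1.0, 1.0, [1, 0, 1, 1])
posterior_mean = a_post / (a_post + b_post)  # Beta(4, 2) has mean 4/6
print(float(a_post), float(b_post), float(posterior_mean))
```

The update is just counting, which is why Beta posteriors are so cheap to maintain online.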

The shape of the distribution is controlled entirely by α and β:

  • α = β = 1: uniform over (0, 1), i.e. no prior knowledge.
  • α < 1, β < 1: U-shaped; the density rises toward both 0 and 1, so mass piles up near the ends (sparse beliefs).
  • α > 1, β > 1: unimodal and bell-shaped, with mode (α - 1)/(α + β - 2) and mean α/(α + β).
  • α ≫ β: concentrated near 1; α ≪ β: concentrated near 0.
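To see these regimes numerically, the sketch below draws 100,000 samples for one (α, β) pair per bullet and compares the empirical mean and variance with the closed forms α/(α + β) and αβ/((α + β)²(α + β + 1)). It also measures how much mass falls near the edges versus near the centre. The specific parameter pairs and the 0.05-wide windows are illustrative choices, not part of the original text.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Illustrative (alpha, beta) pairs: uniform, U-shaped, bell-shaped, skewed high, skewed low
settings = [(1.0, 1.0), (0.5, 0.5), (10.0, 10.0), (20.0, 2.0), (2.0, 20.0)]

results = {}
for i, (a, b) in enumerate(settings):
    k = jax.random.fold_in(key, i)  # derive an independent key per setting
    x = jax.random.beta(k, a, b, shape=(100_000,))
    analytic_mean = a / (a + b)
    analytic_var = a * b / ((a + b) ** 2 * (a + b + 1))
    # Fraction of samples near the edges vs. near the centre (U-shape check)
    edge = float(jnp.mean((x < 0.05) | (x > 0.95)))
    centre = float(jnp.mean(jnp.abs(x - 0.5) < 0.05))
    results[(a, b)] = (float(x.mean()), analytic_mean, float(x.var()), analytic_var, edge, centre)
    print(f"Beta({a}, {b}): mean {float(x.mean()):.3f} vs {analytic_mean:.3f}, "
          f"var {float(x.var()):.4f} vs {analytic_var:.4f}, edge {edge:.3f}, centre {centre:.3f}")
```

For Beta(0.5, 0.5) the edge fraction should exceed the centre fraction, and for Beta(10, 10) the reverse should hold.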

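Thompson sampling draws one sample from each arm's Beta posterior in a multi-armed bandit and pulls the arm with the highest draw; the chosen arm's posterior is then updated with the observed reward. It is a cheap Bayesian strategy that balances exploration and exploitation automatically: arms with uncertain posteriors occasionally produce high draws and get explored, while arms with strong evidence of a high success rate win most of the time.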
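A minimal Thompson sampling loop for a Bernoulli bandit might look like the following sketch. The true success rates in `TRUE_P` are hypothetical and hidden from the agent; every arm starts from a Beta(1, 1) prior, and only the pulled arm's parameters are updated. `thompson_step` and `run_thompson` are illustrative names, and `jax.lax.scan` is used so the loop can be compiled.

```python
import jax
import jax.numpy as jnp

# Hypothetical true success rates for a 3-armed Bernoulli bandit (unknown to the agent)
TRUE_P = jnp.array([0.2, 0.5, 0.7])

def thompson_step(state, key):
    alpha, beta = state
    k_sample, k_reward = jax.random.split(key)
    # One posterior draw per arm; the output shape is broadcast from alpha and beta
    theta = jax.random.beta(k_sample, alpha, beta)
    arm = jnp.argmax(theta)
    # Pull the chosen arm: reward is 1 with probability TRUE_P[arm]
    reward = jax.random.bernoulli(k_reward, TRUE_P[arm]).astype(jnp.float32)
    # Conjugate update of the chosen arm only
    alpha = alpha.at[arm].add(reward)
    beta = beta.at[arm].add(1.0 - reward)
    return (alpha, beta), arm

def run_thompson(seed, steps):
    init = (jnp.ones(3), jnp.ones(3))  # Beta(1, 1) prior on every arm
    keys = jax.random.split(jax.random.PRNGKey(seed), steps)
    (alpha, beta), arms = jax.lax.scan(thompson_step, init, keys)
    return alpha, beta, arms

alpha, beta, arms = run_thompson(seed=0, steps=1000)
pulls = jnp.bincount(arms, length=3)
print("pulls per arm:", pulls)
print("posterior means:", alpha / (alpha + beta))
```

With these rates, most of the 1,000 pulls should go to the arm with rate 0.7, while the weaker arms keep wider posteriors because they are pulled less often.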

Worked mini-example

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key_uniform, key_conc = jax.random.split(key)  # use a fresh key for each draw

# Uniform prior: equivalent to Beta(1, 1)
samples = jax.random.beta(key_uniform, 1.0, 1.0, shape=(4,))
# → float32 array of shape (4,); each value in (0, 1)

# Concentrated near 0.5: Beta(10, 10)
concentrated = jax.random.beta(key_conc, 10.0, 10.0, shape=(3,))
# → clustered around 0.5 (standard deviation ≈ 0.11)

Common pitfalls

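  • α and β must be strictly positive: the Beta distribution is undefined for zero or negative parameters, and JAX does not raise an error for them, so expect NaN or degenerate output instead.
  • n and seed arrive as floats: cast n to int before building the shape tuple, and cast seed to int before passing it to jax.random.PRNGKey.
  • Shape is a tuple: shape=(int(n),), not just int(n).
  • Output is bounded: samples lie in (0, 1), unlike Gaussian draws, which range over all of ℝ.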

Problem

Implement beta_sample(seed, alpha, beta_param, n) that draws n samples from Beta(alpha, beta_param).

seed, alpha, beta_param, and n are Python scalars (floats). Return a 1-D float32 array of shape (n,) with each value in (0, 1).

One illustrative example (not from the test set):

  • beta_sample(0, 1.0, 1.0, 4.0) returns a float32 array of shape (4,) whose values are drawn uniformly from (0, 1); the result is deterministic for seed 0.
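One possible sketch of beta_sample that applies the pitfalls listed earlier: it casts seed and n to int and passes the shape as a tuple. It is not presented as the reference solution; it simply forwards to `jax.random.beta` with an explicit `float32` dtype.

```python
import jax
import jax.numpy as jnp

def beta_sample(seed, alpha, beta_param, n):
    # seed and n may arrive as floats; PRNGKey and shape both need integers
    key = jax.random.PRNGKey(int(seed))
    # shape must be a tuple; dtype pins the output to float32
    return jax.random.beta(key, alpha, beta_param, shape=(int(n),), dtype=jnp.float32)

samples = beta_sample(0, 1.0, 1.0, 4.0)
print(samples.shape, samples.dtype)  # → (4,) float32
```

Calling it again with the same seed returns the same four values, since the key is derived only from seed.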

Hints

Use jax.random.beta.
