
Uniform Sampling

Why this matters

jax.random.uniform is the workhorse for drawing samples from a uniform distribution. Unlike NumPy’s np.random.uniform, it takes an explicit PRNG key as its first argument and relies on no global state. This makes your sampling code pure and reproducible: the same key with the same arguments always produces the same samples.
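To make the reproducibility claim concrete, here is a minimal sketch. The seed 42 and the variable names are arbitrary choices for illustration. It draws twice with the same key, then once each with two subkeys obtained from jax.random.split.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)  # arbitrary seed, chosen for illustration

# Same key, same arguments: the two draws are identical.
a = jax.random.uniform(key, (3,))
b = jax.random.uniform(key, (3,))
same_key_matches = bool(jnp.array_equal(a, b))

# Different subkeys: the draws are expected to differ.
key1, key2 = jax.random.split(key)
c = jax.random.uniform(key1, (3,))
d = jax.random.uniform(key2, (3,))
split_keys_match = bool(jnp.array_equal(c, d))

print(same_key_matches, split_keys_match)
```

Because reusing a key repeats the same numbers, the usual pattern is to split a key whenever you need more than one independent draw.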

The core signature is jax.random.uniform(key, shape=(), dtype=float, minval=0.0, maxval=1.0). By default it draws from [0, 1); pass minval and maxval to sample from [minval, maxval) instead. The shape argument should be a tuple of ints, such as (4,).

Worked mini-example

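import jax

key = jax.random.PRNGKey(0)

# Split into two subkeys so the draws below are independent;
# reusing one key for both would yield correlated samples.
key_x, key_y = jax.random.split(key)

# Default: [0, 1)
x = jax.random.uniform(key_x, (4,))

# Custom range: [-1, 1)
y = jax.random.uniform(key_y, (4,), minval=-1.0, maxval=1.0)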

Common pitfalls

  • Shape must be a tuple: (int(n),) not int(n). A bare integer raises an error.
  • Defaults are [0, 1): forgetting minval/maxval gives you the wrong range silently.
  • PRNGKey needs an int: cast with int(seed) if seed arrives as a float.
  • Upper bound is exclusive: the range is [minval, maxval).
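The range pitfalls above can be checked empirically. The following is a sketch, not a proof: the sample count of 1000 and the variable names are assumptions made for illustration. It confirms that a draw without minval/maxval stays in [0, 1), and that a draw with explicit bounds stays in [minval, maxval).

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key_default, key_custom = jax.random.split(key)
n = 1000  # illustrative sample count

# No minval/maxval: samples fall in the default range [0, 1).
default_draw = jax.random.uniform(key_default, (n,))

# Explicit bounds: samples fall in [-1, 1), upper bound excluded.
custom_draw = jax.random.uniform(key_custom, (n,), minval=-1.0, maxval=1.0)

default_in_unit = bool(jnp.all((default_draw >= 0.0) & (default_draw < 1.0)))
custom_in_range = bool(jnp.all((custom_draw >= -1.0) & (custom_draw < 1.0)))

print(default_draw.shape, default_in_unit, custom_in_range)
```

A quick check like this is cheap to leave in a test suite, since a forgotten minval/maxval produces no error, only the wrong range.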

Problem

Implement uniform_in_range(seed, n, low, high) that draws n samples from Uniform(low, high) using jax.random.uniform.

Both seed and n arrive as floats; cast them to int inside the function.

One illustrative example (not from the test set):

  • uniform_in_range(0, 4, 0.0, 1.0) returns a 1-D array of shape (4,) with values in [0, 1), deterministic for seed 0.
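One possible way to approach the problem is sketched below. It is not the reference solution. It assumes that casting seed and n with int() is all the preprocessing required, and it performs no input validation.

```python
import jax
import jax.numpy as jnp


def uniform_in_range(seed, n, low, high):
    """Draw n samples from Uniform(low, high); seed and n may arrive as floats."""
    key = jax.random.PRNGKey(int(seed))  # PRNGKey needs an integer seed
    shape = (int(n),)                    # shape as a 1-tuple of ints
    return jax.random.uniform(key, shape, minval=low, maxval=high)


samples = uniform_in_range(0.0, 4.0, 0.0, 1.0)
repeat = uniform_in_range(0.0, 4.0, 0.0, 1.0)

# Same seed and arguments give the same samples.
is_deterministic = bool(jnp.array_equal(samples, repeat))
print(samples.shape, is_deterministic)
```

Building the key inside the function keeps it a pure function of its arguments, which is what makes the result deterministic for a given seed.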

Hints

jax.random.uniform
