Uniform Sampling
Why this matters
jax.random.uniform is the workhorse for drawing samples from a uniform
distribution. Unlike NumPy’s np.random.uniform, it takes an explicit PRNG key
as its first argument and relies on no global state. This makes your sampling code
pure and reproducible: the same key always produces the same samples.
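To make the reproducibility claim concrete, here is a minimal sketch. The seed 42 and the variable names are arbitrary choices for illustration. It checks that two calls with the same key return identical arrays, and that subkeys derived with jax.random.split give different draws.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)  # arbitrary seed chosen for illustration

# Same key and same shape: both calls return identical samples.
a = jax.random.uniform(key, (3,))
b = jax.random.uniform(key, (3,))
same = bool(jnp.array_equal(a, b))

# For independent draws, derive fresh subkeys instead of reusing `key`.
key_1, key_2 = jax.random.split(key)
c = jax.random.uniform(key_1, (3,))
d = jax.random.uniform(key_2, (3,))
differ = not bool(jnp.array_equal(c, d))

print(same, differ)
```

Because there is no hidden generator state, calling the function twice with one key never advances anything; getting new numbers is always a matter of passing a new key.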
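The full signature is jax.random.uniform(key, shape=(), dtype=float, minval=0.0, maxval=1.0).
By default it draws from [0, 1), but you can change the range by passing
minval and maxval. Pass them as keyword arguments: dtype sits between shape and
minval, so a third positional argument would be read as the dtype. Pass shape
as a tuple such as (4,), not a bare integer.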
Worked mini-example
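import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)
# Split so each draw gets its own key; reusing one key gives correlated samples.
key_x, key_y = jax.random.split(key)
# Default: [0, 1)
x = jax.random.uniform(key_x, (4,))
# Custom range: [-1, 1)
y = jax.random.uniform(key_y, (4,), minval=-1.0, maxval=1.0)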
Common pitfalls
- Shape must be a tuple: (int(n),) not int(n). A bare integer raises an error.
- Defaults are [0, 1): forgetting minval/maxval gives you the wrong range silently.
- PRNGKey needs an int: cast with int(seed) if seed arrives as a float.
- Upper bound is exclusive: the range is [minval, maxval).
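The pitfalls above can be walked through in a few lines. This is a rough sketch: the float inputs 7.0 and 5.0 are hypothetical values standing in for arguments that arrive as floats, and the range [-1, 1) is reused from the mini-example.

```python
import jax
import jax.numpy as jnp

seed, n = 7.0, 5.0  # hypothetical inputs that arrive as floats

# Cast before use: PRNGKey expects an int seed, and shape is a tuple of ints.
key = jax.random.PRNGKey(int(seed))
shape = (int(n),)

# Pass the range explicitly; leaving it out would silently sample [0, 1).
samples = jax.random.uniform(key, shape, minval=-1.0, maxval=1.0)

# Lower bound inclusive, upper bound exclusive.
in_range = bool(jnp.all((samples >= -1.0) & (samples < 1.0)))
print(samples.shape, in_range)
```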
Problem
Implement uniform_in_range(seed, n, low, high) that draws n samples from
Uniform(low, high) using jax.random.uniform.
Both seed and n arrive as floats; cast them to int inside the function.
One illustrative example (not from the test set):
- uniform_in_range(0, 4, 0.0, 1.0) returns a 1-D array of shape (4,) with values in [0, 1), deterministic for seed 0.
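One possible way to approach the problem is sketched below. It is not the page's reference solution, which is not shown here; it only combines the casts and keyword arguments described earlier, and the variable names out and again are illustrative.

```python
import jax
import jax.numpy as jnp


def uniform_in_range(seed, n, low, high):
    """Draw n samples from Uniform(low, high) using a key built from seed."""
    key = jax.random.PRNGKey(int(seed))  # seed may arrive as a float
    shape = (int(n),)                    # n may arrive as a float
    return jax.random.uniform(key, shape, minval=low, maxval=high)


# Float arguments, as the problem statement describes.
out = uniform_in_range(0.0, 4.0, 0.0, 1.0)
again = uniform_in_range(0.0, 4.0, 0.0, 1.0)
print(out.shape)
```

Calling the function twice with the same seed should give the same array, since the key is rebuilt from the seed on every call.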