Bernoulli Mask Sampling
Why this matters
jax.random.bernoulli(key, p, shape) draws Boolean samples where each entry
is independently True with probability p. This is the foundation of dropout masks,
binary stochastic gates, and variational Bernoulli latents. In practice, you
almost always cast the Boolean array to float so you can multiply element-wise
with activations: mask.astype(jnp.float32).
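As a quick sanity check of the claim above, the sketch below draws a large mask and confirms two things: the raw dtype is bool, and the fraction of ones after casting to float32 is close to p. The seed, the value p = 0.3, the sample size, and the name `empirical_p` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
p = 0.3

# Draw a large Boolean mask; each entry is True with probability p.
mask = jax.random.bernoulli(key, p, (100_000,))
print(mask.dtype)  # bool

# Cast to float32 so the mask can multiply activations element-wise.
mask_f = mask.astype(jnp.float32)

# For a large sample, the fraction of ones should be close to p.
empirical_p = float(mask_f.mean())
print(empirical_p)
```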
Understanding Bernoulli sampling in JAX is a prerequisite for implementing standard dropout or stochastic depth from scratch.
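To connect this to dropout, here is a minimal sketch of inverted dropout (the common variant that rescales kept units at training time) built directly on jax.random.bernoulli. The helper name `dropout` and its `rate` and `train` parameters are hypothetical and are not part of JAX; framework dropout layers may differ in details.

```python
import jax
import jax.numpy as jnp


def dropout(key, x, rate, train=True):
    """Hypothetical inverted-dropout helper built on jax.random.bernoulli.

    Each unit is kept with probability 1 - rate; survivors are scaled by
    1 / (1 - rate) so the expected activation matches the no-dropout case.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    if not train or rate == 0.0:
        return x  # evaluation mode (or no dropout): identity
    keep_prob = 1.0 - rate
    keep = jax.random.bernoulli(key, keep_prob, x.shape)  # bool mask
    # Cast the mask to the activation dtype, zero dropped units, rescale the rest.
    return x * keep.astype(x.dtype) / keep_prob


key = jax.random.PRNGKey(0)
x = jnp.ones((4, 4))
y = dropout(key, x, rate=0.5)  # entries are 0.0 (dropped) or 2.0 (kept, rescaled)
```

Rescaling during training means the evaluation path can return x unchanged, which is why train=False is a plain identity here.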
Worked mini-example
import jax, jax.numpy as jnp
key = jax.random.PRNGKey(0)
# 3x3 dropout mask, keep probability 0.5
mask_bool = jax.random.bernoulli(key, 0.5, (3, 3))
# → bool array of shape (3, 3)
mask_float = mask_bool.astype(jnp.float32)
# → float32 array of 0.0 / 1.0
# Apply to activations:
x = jnp.ones((3, 3))
dropped = x * mask_float
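Because JAX random functions are pure, the mask above depends only on the key: calling jax.random.bernoulli again with the same key returns the identical mask. The sketch below shows this and the usual way to get independent masks (for example one per layer or per training step) by splitting the key first. The names `k_a` and `k_b` are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing the same key reproduces exactly the same mask.
m1 = jax.random.bernoulli(key, 0.5, (3, 3))
m2 = jax.random.bernoulli(key, 0.5, (3, 3))
same = bool(jnp.array_equal(m1, m2))  # True

# For independent masks, split the key and use each subkey once.
key, k_a, k_b = jax.random.split(key, 3)
mask_a = jax.random.bernoulli(k_a, 0.5, (3, 3)).astype(jnp.float32)
mask_b = jax.random.bernoulli(k_b, 0.5, (3, 3)).astype(jnp.float32)
```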
Common pitfalls
- Returns booleans, not floats: the raw output has dtype bool. Cast explicitly with mask.astype(jnp.float32) before doing arithmetic or returning a numeric mask.
- Shape arguments must be Python ints: pass int(shape[0]), not the raw float value of a JAX array element.
- PRNGKey needs an integer seed: jax.random.PRNGKey(int(seed)).
- p=0.0 is valid and produces an all-zero mask (never True).
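The first and last pitfalls can be checked directly. The sketch below confirms the raw dtype is bool and shows the two extreme probabilities: p=0.0 never yields True and p=1.0 always does. This relies on the sample being True exactly when a uniform draw in [0, 1) falls below p.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

raw = jax.random.bernoulli(key, 0.5, (2, 2))
print(raw.dtype)  # bool: cast before using it as a numeric mask

# p=0.0 never yields True; p=1.0 always does.
all_zero = jax.random.bernoulli(key, 0.0, (2, 2)).astype(jnp.float32)
all_one = jax.random.bernoulli(key, 1.0, (2, 2)).astype(jnp.float32)
```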
Problem
Implement bernoulli_mask(seed, p, shape) where shape is a 1-D JAX array
of length 2 containing [H, W] as floats. Draw a Bernoulli(p) mask of shape
(H, W) and return it as a float32 array of 0.0s and 1.0s.
One illustrative example (not from the test set):
- bernoulli_mask(0, 0.5, jnp.array([3.0, 3.0])) returns a float32 array of shape (3, 3) with 0.0/1.0 entries, deterministic for seed 0.
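One way to put the pitfalls together is the minimal sketch below. It assumes shape is a concrete (non-traced) array, so converting its entries with int() works; under jax.jit with a traced shape this conversion would fail. It is one possible approach, not the reference solution.

```python
import jax
import jax.numpy as jnp


def bernoulli_mask(seed, p, shape):
    """Sketch: Bernoulli(p) mask of shape (H, W) as float32 0.0/1.0 values.

    Assumes `shape` is a concrete 1-D array like jnp.array([H, W]) holding
    float values, so its entries can be converted to Python ints.
    """
    h, w = int(shape[0]), int(shape[1])          # shape entries must be Python ints
    key = jax.random.PRNGKey(int(seed))          # PRNGKey needs an integer seed
    mask = jax.random.bernoulli(key, p, (h, w))  # bool array, True with probability p
    return mask.astype(jnp.float32)              # cast to 0.0 / 1.0 floats


m = bernoulli_mask(0, 0.5, jnp.array([3.0, 3.0]))
```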