
Bernoulli Mask Sampling

Why this matters

jax.random.bernoulli(key, p, shape) draws Boolean samples where each entry is True with probability p. This is the foundation of dropout masks, binary stochastic gates, and variational Bernoulli latents. In practice, you almost always cast the Boolean array to float so you can multiply element-wise with activations: mask.astype(jnp.float32).
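A quick sanity check of both claims, sketched with an arbitrary key and p=0.3 (illustrative choices, not from the original): the raw samples have dtype bool, and once they are cast to float32 their mean lands close to p.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)  # arbitrary seed for illustration
p = 0.3

# Draw many Boolean samples; each entry is True with probability p.
samples = jax.random.bernoulli(key, p, (100_000,))
print(samples.dtype)  # bool

# Casting maps True/False to 1.0/0.0, so the mean estimates p.
freq = samples.astype(jnp.float32).mean()
print(float(freq))  # close to 0.3; the exact value depends on the key
```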

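Understanding Bernoulli sampling in JAX is a prerequisite for implementing dropout or stochastic depth from scratch: dropout draws one Bernoulli gate per activation, while stochastic depth draws one gate per example to keep or skip an entire residual branch.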
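As a rough illustration of stochastic depth, the sketch below gates a whole residual branch with one Bernoulli draw per example. The helper name stochastic_depth, the assumption that axis 0 is the batch axis, and the inverted scaling by 1/survival_prob are choices made for this sketch, not details taken from the text above.

```python
import jax
import jax.numpy as jnp

def stochastic_depth(key, x, branch_out, survival_prob):
    """Hypothetical helper: residual update with a per-example Bernoulli gate.

    Assumes axis 0 of x is the batch axis and 0 < survival_prob <= 1.
    """
    # One gate per example, broadcastable over all remaining axes.
    gate_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    gate = jax.random.bernoulli(key, survival_prob, gate_shape).astype(x.dtype)
    # Rescale kept branches by 1/survival_prob so the expected output
    # matches the deterministic x + branch_out used at inference time.
    return x + branch_out * gate / survival_prob

key = jax.random.PRNGKey(0)
x = jnp.zeros((4, 8))
branch_out = jnp.ones((4, 8))
y = stochastic_depth(key, x, branch_out, survival_prob=0.8)
# Each row is either all 0.0 (branch dropped) or all 1.25 (kept, rescaled).
```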

Worked mini-example

import jax, jax.numpy as jnp

key = jax.random.PRNGKey(0)

# 3x3 dropout mask, keep probability 0.5
mask_bool = jax.random.bernoulli(key, 0.5, (3, 3))
# → bool array of shape (3, 3)

mask_float = mask_bool.astype(jnp.float32)
# → float32 array of 0.0 / 1.0

# Apply to activations:
x = jnp.ones((3, 3))
dropped = x * mask_float
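The mini-example multiplies by the raw mask. Many dropout implementations also rescale the kept activations by 1/keep_prob ("inverted dropout"), so the expected activation is unchanged and the layer can be the identity at inference. A minimal sketch, assuming a hypothetical helper named dropout and 0 < keep_prob <= 1:

```python
import jax
import jax.numpy as jnp

def dropout(key, x, keep_prob):
    """Hypothetical inverted-dropout helper (training mode only).

    Assumes 0 < keep_prob <= 1.
    """
    # x.shape is already a tuple of Python ints, so it can be passed directly.
    mask = jax.random.bernoulli(key, keep_prob, x.shape).astype(x.dtype)
    # Rescale surviving activations so that E[output] == x.
    return x * mask / keep_prob

key = jax.random.PRNGKey(0)
x = jnp.ones((3, 3))
y = dropout(key, x, keep_prob=0.5)
# Every entry is either 0.0 (dropped) or 2.0 (kept and rescaled by 1/0.5).
```

Scaling at training time rather than at inference is a design choice; the unscaled mask in the mini-example is equally valid if the rescaling happens elsewhere.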

Common pitfalls

  • Returns BOOLEANS, not floats: the raw output is bool. Cast to float before doing arithmetic: mask.astype(jnp.float32).
  • Shape arguments must be Python ints: build the shape as (int(shape[0]), int(shape[1])) rather than passing float elements of a JAX array directly.
  • PRNGKey needs an int: jax.random.PRNGKey(int(seed)).
  • p=0.0 is valid and produces an all-zero mask (never True).
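The pitfalls above can be seen together in one short snippet. The seed and shape values are made-up inputs chosen to mimic float-typed arguments.

```python
import jax
import jax.numpy as jnp

seed = 7.0                     # illustrative seed that arrived as a float
shape = jnp.array([2.0, 4.0])  # illustrative [H, W] stored as floats

key = jax.random.PRNGKey(int(seed))       # PRNGKey needs an integer seed
hw = (int(shape[0]), int(shape[1]))       # shape entries must be Python ints
mask = jax.random.bernoulli(key, 0.5, hw).astype(jnp.float32)  # bool -> float32

# p=0.0 is a valid edge case: no entry is ever True, so the mask is all zeros.
empty = jax.random.bernoulli(key, 0.0, hw).astype(jnp.float32)
```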

Problem

Implement bernoulli_mask(seed, p, shape) where shape is a 1-D JAX array of length 2 containing [H, W] as floats. Draw a Bernoulli(p) mask of shape (H, W) and return it as a float32 array of 0.0s and 1.0s.

One illustrative example (not from the test set):

  • bernoulli_mask(0, 0.5, jnp.array([3.0, 3.0])) returns a float32 array of shape (3, 3) with 0.0/1.0 entries, deterministic for seed 0.

Hints

Use jax.random.bernoulli.
