Random Permutation
Why this matters
Shuffling data is one of the most fundamental operations in machine learning:
it removes ordering bias from stochastic gradient descent, enables random
mini-batch construction, and randomizes example order within the stages of a
curriculum-learning schedule.
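jax.random.permutation(key, x) returns a randomly permuted copy of x, which is
consistent with JAX’s functional, immutable design. By default it shuffles along
axis 0; if x is an integer n, it returns a shuffled jnp.arange(n) instead.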
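A minimal sketch of the copy semantics described above: the input array is left untouched, and the result holds the same values (a permutation, so sorting both gives equal arrays). The sample values are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.array([10, 20, 30, 40, 50])

# permutation builds a new array; x itself is never modified in place.
y = jax.random.permutation(key, x)

print(x)  # [10 20 30 40 50]
# Same multiset of values, possibly in a different order.
print(bool(jnp.array_equal(jnp.sort(y), jnp.sort(x))))  # True
```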
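Because JAX uses an explicit key, a given (key, x) pair always produces the
same shuffled array, making training pipelines reproducible across runs and
machines for a fixed JAX version and PRNG implementation. The orderings will not
match those of other frameworks (NumPy, PyTorch), which use different generators.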
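The sketch below illustrates the reproducibility claim: calling permutation twice with the same key gives identical results, while splitting the key before each epoch yields a fresh subkey per epoch. The seeds 42 and 0 and the three-epoch loop are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8)

# Same (key, x) pair: the result is identical every time it is computed.
demo_key = jax.random.PRNGKey(42)
a = jax.random.permutation(demo_key, x)
b = jax.random.permutation(demo_key, x)

# Splitting before each use gives every epoch its own subkey, so orderings
# can differ between epochs while the whole sequence is fixed by seed 0.
key = jax.random.PRNGKey(0)
orders = []
for epoch in range(3):
    key, subkey = jax.random.split(key)
    orders.append(jax.random.permutation(subkey, x))
```

Rerunning the loop from PRNGKey(0) reproduces exactly the same list of orderings.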
Worked mini-example
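import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
shuffled = jax.random.permutation(key, x)
# → float32 array of shape (5,); same values, shuffled order

# For reproducible batching, derive one key per epoch from a base key:
dataset = jnp.arange(100.0)  # stand-in data for illustration
base_key = jax.random.PRNGKey(1)
for epoch in range(3):
    epoch_key = jax.random.fold_in(base_key, epoch)
    indices = jax.random.permutation(epoch_key, jnp.arange(len(dataset)))
    epoch_data = dataset[indices]  # examples in this epoch's order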
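Shuffled indices are usually consumed in fixed-size slices to form mini-batches. The generator below is a hypothetical helper (iterate_minibatches is not part of JAX) that sketches this pattern, assuming the examples are stacked along axis 0 of a single array. It uses the integer form of permutation, which shuffles jnp.arange(n).

```python
import jax
import jax.numpy as jnp

def iterate_minibatches(key, data, batch_size):
    """Hypothetical helper: yield shuffled mini-batches covering every row once.

    The last batch is smaller when len(data) is not a multiple of batch_size.
    """
    n = data.shape[0]
    perm = jax.random.permutation(key, n)  # shuffled 0..n-1
    for start in range(0, n, batch_size):
        idx = perm[start:start + batch_size]
        yield data[idx]

data = jnp.arange(10.0).reshape(10, 1)  # toy dataset: 10 examples, 1 feature
base_key = jax.random.PRNGKey(0)
for epoch in range(2):
    epoch_key = jax.random.fold_in(base_key, epoch)
    batches = list(iterate_minibatches(epoch_key, data, batch_size=4))
```

Keeping the partial final batch is a design choice; dropping it instead gives every step a uniform batch shape, which avoids recompilation under jax.jit.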
Common pitfalls
- JAX arrays are immutable: permutation returns a new array; the original is unchanged.
- Same key → same permutation: use jax.random.fold_in or split keys each epoch to get different orderings.
- Integer arrays work too: jax.random.permutation(key, jnp.arange(n)) shuffles indices, which is the standard batch-sampler pattern. Passing the integer directly, jax.random.permutation(key, n), gives the same kind of result.
- independent kwarg (JAX ≥ 0.4): with independent=True, each 1-D slice along axis is permuted on its own, so for a 2-D array axis=1 shuffles within each row independently. The default (False) reorders whole slices along axis, which means entire rows for the default axis=0.
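The difference between the two modes can be checked on a small matrix. In this sketch the default call keeps every row intact and only reorders them, while independent=True with axis=1 permutes the entries inside each row separately. The 3×4 input is an arbitrary example.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
m = jnp.arange(12).reshape(3, 4)  # rows: [0..3], [4..7], [8..11]

# Default (independent=False, axis=0): whole rows move; each row stays intact.
rows_shuffled = jax.random.permutation(key, m)

# independent=True, axis=1: every row receives its own permutation of its entries.
within_rows = jax.random.permutation(key, m, axis=1, independent=True)
```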
Problem
Implement shuffle_array(seed, x) that returns a random permutation of the
1-D array x.
seed is a Python scalar (float); x is a 1-D JAX array. JAX expects an integer
seed when creating a key, so convert it (for example with int(seed)) first.
Return a 1-D array of the same shape as x containing the same values in a random order.
One illustrative example (not from the test set):
- shuffle_array(0, jnp.array([1., 2., 3., 4., 5.])) returns a float32 array of shape (5,): a permutation of the input, deterministic for seed 0.
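One possible sketch of shuffle_array, assuming a float seed should simply be truncated to an integer before building the key; it is not necessarily the intended reference solution.

```python
import jax
import jax.numpy as jnp

def shuffle_array(seed, x):
    # JAX PRNG seeds are integers; int() converts a float seed such as 0.0.
    key = jax.random.PRNGKey(int(seed))
    # Shuffle the 1-D input along its only axis; x itself is left unchanged.
    return jax.random.permutation(key, x)
```

Because int() truncates, seeds such as 1.0 and 1.7 map to the same key under this assumption.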