easy primitives

Random Permutation

Why this matters

Shuffling data is one of the most fundamental operations in machine learning: it removes ordering effects from stochastic gradient descent, enables random mini-batch construction, and randomizes the order of examples within each stage of a curriculum. jax.random.permutation(key, x) returns a randomly permuted COPY of x, consistent with JAX's functional, immutable design. If x is an integer n instead of an array, it returns a shuffled jnp.arange(n).

Because JAX uses an explicit key, a given (key, x) pair always produces the same shuffled array, so your training pipelines are reproducible across runs and machines (for a fixed JAX version and PRNG configuration).
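A quick sketch of that determinism: reusing a key reproduces the ordering exactly, while a key obtained from jax.random.split gives a fresh one. The seed 42 and the toy input are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10)

# Same key and same input: the output is identical on every call.
key = jax.random.PRNGKey(42)
a = jax.random.permutation(key, x)
b = jax.random.permutation(key, x)
same = bool(jnp.array_equal(a, b))  # True

# Splitting yields a new, independent key and therefore a new ordering.
key, subkey = jax.random.split(key)
c = jax.random.permutation(subkey, x)

# Either way, the result is a rearrangement of the same values.
assert jnp.array_equal(jnp.sort(a), x)
assert jnp.array_equal(jnp.sort(c), x)
```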

Worked mini-example

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

shuffled = jax.random.permutation(key, x)
# → float32 array of shape (5,); same values, generally in a different order

# For reproducible batching, derive one key per epoch from a base key:
num_examples = 8  # stand-in for len(dataset)
for epoch in range(3):
    epoch_key = jax.random.fold_in(key, epoch)
    indices = jax.random.permutation(epoch_key, jnp.arange(num_examples))
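The per-epoch index pattern extends naturally to a mini-batch loop. This is one possible sketch, not a JAX API: iterate_minibatches is a hypothetical helper, and dropping a trailing partial batch is a design assumption you may want to change.

```python
import jax
import jax.numpy as jnp

def iterate_minibatches(key, data, batch_size):
    """Hypothetical helper: yield shuffled mini-batches from `data`.

    A trailing partial batch is dropped (an assumption of this sketch).
    """
    n = data.shape[0]
    perm = jax.random.permutation(key, n)  # int shorthand: shuffles arange(n)
    for start in range(0, n - batch_size + 1, batch_size):
        idx = perm[start:start + batch_size]
        yield data[idx]

base_key = jax.random.PRNGKey(0)
data = jnp.arange(10.0).reshape(10, 1)  # 10 toy examples, 1 feature each

for epoch in range(2):
    epoch_key = jax.random.fold_in(base_key, epoch)  # new order each epoch
    batches = list(iterate_minibatches(epoch_key, data, batch_size=4))
    # 10 examples with batch_size=4 -> 2 full batches; 2 examples are dropped
```

Dropping the last partial batch keeps every batch the same shape, which avoids recompiling a jitted training step for an odd-sized final batch.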

Common pitfalls

  • JAX arrays are immutable: permutation returns a new array; the original is unchanged.
  • Same key → same permutation: use jax.random.fold_in or split keys each epoch to get different orderings.
  • Integer arrays work too: jax.random.permutation(key, jnp.arange(n)) shuffles indices, which is the standard batch-sampler pattern.
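  • independent and axis kwargs (available in JAX ≥ 0.4): independent=True shuffles each 1-D slice along axis separately, so for a 2-D array with the default axis=0 every column gets its own permutation (pass axis=1 to shuffle within each row); the default (False) applies one permutation to whole slices along axis, i.e. it reorders the rows.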
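A small sketch contrasting the three behaviours on a 3×4 matrix, using the signature permutation(key, x, axis=0, independent=False). The matrix and seed are arbitrary; the checks only confirm which values can move where, not the specific orderings.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
m = jnp.arange(12).reshape(3, 4)  # rows: [0..3], [4..7], [8..11]

# Default: a single permutation applied to whole rows (axis=0).
rows_shuffled = jax.random.permutation(key, m)

# independent=True, axis=1: each row's entries are shuffled separately.
within_rows = jax.random.permutation(key, m, axis=1, independent=True)

# independent=True with the default axis=0: each column is shuffled separately.
within_cols = jax.random.permutation(key, m, independent=True)
```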

Problem

Implement shuffle_array(seed, x) that returns a random permutation of the 1-D array x.

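seed is a Python scalar (float); x is a 1-D JAX array. Return a 1-D array of the same shape as x containing the same values in a random order. Because jax.random.PRNGKey accepts only integer seeds and raises a TypeError for a float, convert the seed (e.g. with int(seed)) before building the key.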

One illustrative example (not from the test set):

  • shuffle_array(0, jnp.array([1., 2., 3., 4., 5.])) returns a float32 array of shape (5,) — a permutation of the input, deterministic for seed 0.
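To sanity-check an implementation locally, a property check is often more useful than comparing against a fixed expected order. The helper below is hypothetical and makes no claim about how the actual tests are written; it only verifies that the output has the input's shape and the same multiset of values.

```python
import jax
import jax.numpy as jnp

def is_permutation_of(y, x):
    """Hypothetical test helper: True if y has x's shape and holds the same
    values as x, compared after sorting."""
    if y.shape != x.shape:
        return False
    return bool(jnp.array_equal(jnp.sort(y), jnp.sort(x)))

x = jnp.array([1., 2., 3., 4., 5.])
print(is_permutation_of(x[::-1], x))                          # True
print(is_permutation_of(jnp.array([1., 1., 3., 4., 5.]), x))  # False
print(is_permutation_of(jax.random.permutation(jax.random.PRNGKey(0), x), x))  # True
```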

Hints

jax random permutation shuffle
