
Pure Function vs Impure

Implement a pure function that, given an external state array and a new value x, appends x to the state and returns both the new state and its mean — packed into a single output array.

Why this matters

JAX is built around tracing: when you wrap a function with jax.jit, JAX runs it once (per combination of input shapes and dtypes) with abstract “tracer” objects in place of real arrays to record the operations. From that trace, JAX builds a program that XLA compiles and optimizes. The compiled program only matches your Python function if the function is pure: its output depends only on its explicit inputs, and it has no side effects on the outside world.
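A small sketch of what tracing records, assuming a recent JAX install (the function f is a hypothetical example). jax.make_jaxpr runs only the tracing step and returns the recorded program, called a jaxpr. The print inside f shows a tracer carrying shape and dtype, not concrete numbers.

```python
import jax
import jax.numpy as jnp

def f(x):
    # During tracing, x is an abstract tracer (shape and dtype only), not real data.
    print("traced with:", x)
    return jnp.sin(x) * 2.0

# make_jaxpr performs the tracing step and returns the recorded operations.
jaxpr = jax.make_jaxpr(f)(jnp.ones(3))
print(jaxpr)
```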

PyTorch lets you mutate global state, append to lists, print intermediate values, and use Python’s if on tensor values inside an autograd graph. Inside a jitted JAX function, none of these work that way. Python-level operations are not re-run on each call: they run once during tracing, so their effects are either baked into the compiled program or silently dropped on later calls. A Python if on a traced value fails outright, because the tracer has no concrete value to test.
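The run-once-at-trace-time behavior is easy to observe with a counter. In this sketch (double and trace_count are hypothetical names), the Python increment runs only when JAX traces. A repeat call with the same input shape and dtype reuses the compiled program and skips it, while a new shape triggers a fresh trace.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only while JAX is tracing
    return 2 * x

double(jnp.ones(3))  # first call: traces and compiles
double(jnp.ones(3))  # same shape and dtype: cached program, no Python execution
double(jnp.ones(4))  # new shape: traces again
print(trace_count)
```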

The pure-function discipline is therefore the price of admission for jax.jit, jax.grad, jax.vmap, and the rest of the transformations. It’s also a great fit for parallelism: pure functions are trivial to run in any order, on any device, with the same result.
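As a sketch of that composability, here is a hypothetical pure loss function passed unchanged to jax.grad, jax.vmap, and jax.jit. Because its result depends only on w and x, each transformation can trace it without surprises.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Pure: the result depends only on the arguments w and x.
    return jnp.sum((w * x) ** 2)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

g = jax.grad(loss)(w, x)  # gradient w.r.t. w: 2 * w * x**2
batched = jax.vmap(loss, in_axes=(None, 0))(w, jnp.stack([x, 2 * x]))  # one loss per row
fast = jax.jit(loss)  # compiled version of the same function
print(g, batched, fast(w, x))
```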

Worked mini-example

Impure (broken under jit):

state = []  # global

def append_and_mean(x):
    state.append(x)            # ← side-effect fires once (at trace time); compiled XLA ignores `state` entirely on later calls
    return sum(state) / len(state)
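Running the impure version under jax.jit makes the failure concrete. This sketch wraps it with @jax.jit and calls it twice. The first call traces, appends a tracer to the global list, and compiles a program that simply returns x. The second call reuses that program, so it returns its own input (3.0) instead of the running mean (2.0) and never touches state.

```python
import jax
import jax.numpy as jnp

state = []  # global, mutated as a side effect

@jax.jit
def append_and_mean(x):
    state.append(x)  # runs once, at trace time; appends a tracer, not a value
    return sum(state) / len(state)  # traced as x / 1

a = append_and_mean(jnp.float32(1.0))  # traces and compiles
b = append_and_mean(jnp.float32(3.0))  # cached program: returns x, skips the append
print(a, b, len(state))
```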

Pure version (the one this problem asks for):

import jax.numpy as jnp

def append_and_mean(state, x):
    # Build a new array instead of mutating state; JAX arrays are immutable.
    new_state = jnp.concatenate([state, jnp.atleast_1d(x)])
    return new_state, jnp.mean(new_state)

The pure version takes state as an argument and returns the new state alongside the result. The caller threads the state through the next call.
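Threading the state through repeated calls might look like this sketch, which jits the pure version and feeds each returned state into the next call. The loop values are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.jit
def append_and_mean(state, x):
    new_state = jnp.concatenate([state, jnp.atleast_1d(x)])
    return new_state, jnp.mean(new_state)

state = jnp.zeros((0,), dtype=jnp.float32)  # start from an empty state
for x in [1.0, 2.0, 6.0]:
    # The returned state becomes the input to the next call.
    state, mean = append_and_mean(state, jnp.float32(x))
print(state, mean)
```

Because new_state grows by one element per call, every call has a new input shape, so jit retraces and recompiles each time. That is fine for a demo; a long-running loop would more commonly keep a fixed-size buffer plus a count so the shapes stay constant (an alternative design, not part of this problem).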

Common pitfalls

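  • Globals: reading a global inside a jitted function captures its value at trace time, so later changes to it are ignored, and writing to a global is a side effect. Pass the value in as an argument instead.
  • print(): only runs during tracing (the first call into a jitted function, or a later call whose input shapes or dtypes differ from every earlier trace). Other calls execute the cached compiled XLA program, not the Python function, so print is silent; when it does run, it shows tracers rather than values. Use jax.debug.print for runtime logging that runs on every call.
  • Python lists with .append: same problem. The append runs once at trace time, and the compiled program never repeats it.
  • Methods that read or mutate self: mutating self is a side effect, and attributes read from self are frozen at trace time. Turn them into pure functions that take the relevant fields as arguments and return updated values.
  • if x[0] > 0: on a traced value: raises a concretization error at trace time, because the tracer has no concrete value to convert to a Python bool. Use jnp.where or jax.lax.cond.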
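The runtime-safe replacements from the list above, in one sketch (signed_relu and python_if are hypothetical names): jax.debug.print for logging, jnp.where for an elementwise choice, and jax.lax.cond for a branch on a traced scalar. The Python if version is included to show that it fails at trace time; catching jax.errors.ConcretizationTypeError also covers its more specific subclass TracerBoolConversionError raised by recent JAX versions.

```python
import jax
import jax.numpy as jnp

@jax.jit
def signed_relu(x):
    # Runtime logging: unlike print(), this fires on every call.
    jax.debug.print("x[0] = {v}", v=x[0])
    # Elementwise select on a traced condition (both branches are evaluated).
    y = jnp.where(x > 0, x, 0.0)
    # Branch on a traced scalar: lax.cond applies one of two functions to y.
    return jax.lax.cond(x[0] > 0, lambda v: v, lambda v: -v, y)

def python_if(x):
    # bool() on a tracer is not allowed, so this fails when traced.
    return x if x[0] > 0 else -x

try:
    jax.jit(python_if)(jnp.ones(2))
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

print(signed_relu(jnp.array([1.0, -2.0, 3.0])))
print(signed_relu(jnp.array([-1.0, 2.0])))
```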

Problem

Implement make_pure(state, x) that takes a 1-D JAX array state and a 1-D JAX array x of shape (1,). Return a packed 1-D array of shape (state.shape[0] + 2,) containing:

  • the first state.shape[0] + 1 entries are new_state (state with x appended),
  • the last entry is mean(new_state).

This packing convention exists so the test harness can deliver a single output array. In real code you would return the tuple (new_state, mean) directly.
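One possible sketch of make_pure under the packing convention above; it is not the site's reference solution. jnp.mean with keepdims=True returns a shape (1,) array, so it can be concatenated directly onto new_state.

```python
import jax
import jax.numpy as jnp

@jax.jit
def make_pure(state, x):
    new_state = jnp.concatenate([state, x])    # x already has shape (1,)
    mean = jnp.mean(new_state, keepdims=True)  # shape (1,) so it can be appended
    return jnp.concatenate([new_state, mean])  # shape (state.shape[0] + 2,)

out = make_pure(jnp.array([1.0, 2.0]), jnp.array([6.0]))
new_state, mean = out[:-1], out[-1]  # unpack the packed output
print(out)
```

A caller that wants the tuple form back can split the packed output with out[:-1] and out[-1], as in the last line.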

Hints

jax purity tracing
