Pure Function vs Impure
Implement a pure function that, given an external state array and a new
value x, appends x to the state and returns both the new state and its
mean — packed into a single output array.
Why this matters
JAX is built around tracing: when you wrap a function with jax.jit, JAX
runs it with abstract “tracer” objects in place of real arrays to record
the operations, once for each new combination of input shapes and dtypes.
From the trace, JAX builds an optimized XLA program. But this only works if
your function is pure: its output depends only on its explicit inputs, and it
has no side effects on the outside world.
PyTorch lets you mutate global state, append to lists, print intermediate
values, and use Python’s if on tensor values inside an autograd graph.
Inside a jitted JAX function, none of these behave the way they do in eager
Python. Python-level operations are not re-executed on each call: they run
once during tracing, so their effects are either baked into the compiled
program or silently dropped at run time. A Python if on a traced value cannot
even be traced; it raises an error because the tracer has no concrete value.
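The tracing behavior described above can be observed directly. The sketch below is a minimal illustration; the names trace_log and double are hypothetical. The Python-level append inside the jitted function runs only while JAX traces, which happens on the first call and again whenever the input shape changes.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical Python list, mutated inside the jitted function


@jax.jit
def double(x):
    # Ordinary Python: this line runs only while JAX traces the function.
    trace_log.append(x.shape)
    return 2 * x


double(jnp.ones(3))  # first call: traces, so trace_log grows
double(jnp.ones(3))  # same shape and dtype: cached program, no Python runs
double(jnp.ones(4))  # new shape: retraces, so trace_log grows again

print(trace_log)  # [(3,), (4,)]
```

Three calls, but only two traces: the second call never touches the Python body.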
The pure-function discipline is therefore the price of admission for jax.jit,
jax.grad, jax.vmap, and the rest of the transformations. It’s also a great
fit for parallelism: pure functions are trivial to run in any order, on any
device, with the same result.
Worked mini-example
Impure (broken under jit):
import jax

state = []  # global

@jax.jit
def append_and_mean(x):
    state.append(x)  # ← side effect fires once (at trace time); the compiled XLA program ignores `state` entirely on later calls
    return sum(state) / len(state)
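To see the failure concretely, here is a self-contained sketch that calls the impure function twice under jax.jit; the input values are illustrative assumptions. The second call reuses the program compiled during the first trace, so it returns x itself rather than the running mean.

```python
import jax

state = []  # global, mutated by the function below


@jax.jit
def append_and_mean(x):
    state.append(x)  # runs only while tracing, so state receives a tracer
    return sum(state) / len(state)  # traced once as x / 1


first = append_and_mean(1.0)   # traces: result 1.0
second = append_and_mean(3.0)  # cached program: result 3.0, not mean([1.0, 3.0]) == 2.0
```

After these calls the global state holds a single leaked tracer from the first trace, not the two values the caller passed in.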
Pure version (the one this problem asks for):
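import jax.numpy as jnp

def append_and_mean(state, x):
    # Pure: reads only its arguments and returns the updated state with the result.
    new_state = jnp.concatenate([state, jnp.atleast_1d(x)])
    return new_state, jnp.mean(new_state)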
The pure version takes state as an argument and returns the new state alongside
the result. The caller threads the state through the next call.
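A minimal sketch of that threading, assuming the pure append_and_mean above wrapped in jax.jit; the starting values are illustrative. The loop, not the function, owns the state between calls.

```python
import jax
import jax.numpy as jnp


@jax.jit
def append_and_mean(state, x):
    new_state = jnp.concatenate([state, jnp.atleast_1d(x)])
    return new_state, jnp.mean(new_state)


state = jnp.array([1.0, 2.0])
means = []
for x in [3.0, 5.0]:
    # The caller carries the state from one call to the next.
    state, m = append_and_mean(state, x)
    means.append(float(m))

# state is now [1., 2., 3., 5.]; means == [2.0, 2.75]
```

Because the state grows by one element per call, every call has a new input shape and triggers a fresh trace. One common way to avoid that is to keep a fixed-size buffer plus a count instead of a growing array.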
Common pitfalls
- Globals: reading or writing a global X inside a JAX function is an impure access; jit captures the value it saw at trace time. Pass the value in as an argument instead.
- print(): only fires during tracing (the first call into a jit-compiled function, or the first call after an input shape or dtype change). On subsequent calls, JAX runs the compiled XLA program, not the Python function, so print is silent. Use jax.debug.print for runtime logging that runs on every call.
- Python lists with .append: same problem; the append happens once at trace time, and the trace bakes it in.
- Class methods reading self: turn them into pure functions that take the relevant fields as arguments.
- Python if on a traced value (e.g. if x[0] > 0:): the tracer has no concrete boolean, so JAX raises a concretization error (TracerBoolConversionError). Use jnp.where or jax.lax.cond.
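The sketch below illustrates three of these pitfalls and their fixes; SCALE, scaled, and branchy are hypothetical names. A global read is frozen into the compiled program, jax.debug.print runs on every call, and jnp.where or jax.lax.cond replace a Python if on a traced value.

```python
import jax
import jax.numpy as jnp

SCALE = 2.0  # hypothetical global read inside a jitted function


@jax.jit
def scaled(x):
    return x * SCALE  # SCALE is captured as a constant at trace time


a = scaled(jnp.ones(2))  # [2., 2.]
SCALE = 3.0
b = scaled(jnp.ones(2))  # still [2., 2.]: cached program, old constant


@jax.jit
def branchy(x):
    # Unlike print(), this fires on every call, with the runtime value.
    jax.debug.print("x[0] = {}", x[0])
    # Elementwise select: both sides are computed, then masked.
    clipped = jnp.where(x > 0, x, 0.0)
    # Whole-value branch on a traced scalar predicate, instead of `if x[0] > 0:`.
    flipped = jax.lax.cond(x[0] > 0, lambda v: v, lambda v: -v, x)
    return clipped, flipped


c, f = branchy(jnp.array([-1.0, 2.0]))  # c == [0., 2.], f == [1., -2.]
```

jnp.where suits per-element choices; jax.lax.cond picks one branch function for the whole value based on a scalar predicate, and both branches must return the same shape and dtype.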
Problem
Implement make_pure(state, x) that takes a 1-D JAX array state and a
1-D JAX array x of shape (1,). Return a packed 1-D array of shape
(state.shape[0] + 2,) containing:
- the first state.shape[0] + 1 entries are new_state (state with x appended),
- the last entry is mean(new_state).
This packing convention exists so the test harness can deliver a single
output array. In real code you would return the tuple (new_state, mean) directly.
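A minimal sketch of one way to meet this packing convention, assuming x already arrives with shape (1,) as specified; it is not necessarily the reference solution. Using keepdims=True gives the mean shape (1,), so it can be concatenated directly.

```python
import jax
import jax.numpy as jnp


@jax.jit
def make_pure(state, x):
    # state: shape (n,); x: shape (1,).
    new_state = jnp.concatenate([state, x])        # shape (n + 1,)
    mean = jnp.mean(new_state, keepdims=True)      # shape (1,)
    return jnp.concatenate([new_state, mean])      # shape (n + 2,)


out = make_pure(jnp.array([1.0, 2.0, 3.0]), jnp.array([6.0]))
# out == [1., 2., 3., 6., 3.]
```

In real code you would return (new_state, mean) directly and let the caller thread new_state into the next call.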