medium
primitives
jaxpr with jit
Why this matters
jax.make_jaxpr and jax.jit operate at different abstraction layers of JAX's compilation stack:
Python function
↓ make_jaxpr (tracing)
JAXPR (JAX IR; you can inspect this)
↓ jit (lowering)
XLA HLO (device-level IR; inspect via jax.jit(f).lower(x), formerly jax.xla_computation)
↓ XLA (compilation)
Device binary
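The stack can be walked layer by layer in code. This is a minimal sketch. It assumes a recent JAX release that provides the ahead-of-time APIs jax.jit(f).lower(...) and Lowered.compile(). The function f is an arbitrary illustrative example, not part of the problem.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Illustrative function; any jittable function works here.
    return jnp.sin(x) * 2.0


x = jnp.ones(3)

# Layer 1: tracing produces a ClosedJaxpr, the JAX-level IR.
closed_jaxpr = jax.make_jaxpr(f)(x)
print(closed_jaxpr)

# Layer 2: lowering the jitted function yields a Lowered object;
# as_text() returns the lowered module (StableHLO text by default).
lowered = jax.jit(f).lower(x)
print(lowered.as_text())

# Layer 3: compiling the lowered module gives an executable for the current backend.
compiled = lowered.compile()
print(compiled(x))  # same values as f(x)
```

The jaxpr shows whether a surprise comes from tracing. Anything that only appears after lowering or compilation is visible solely in the later stages.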
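Understanding this stack matters for:
- Debugging unexpected retracing: make_jaxpr shows the jaxpr traced for a particular set of argument shapes, dtypes and static values. Comparing its output across calls shows what triggered a fresh trace (a shape change, a new static argument, etc.).
- Inspecting jit internals: make_jaxpr(jit(f))(x) reveals the structure JAX sees around the jit boundary. This is a single call equation whose parameters hold the traced body of f.
- Choosing the right tool: use make_jaxpr for JAX-level IR. Use jax.jit(f).lower(x).as_text() (jax.xla_computation in older releases) for the device-level IR handed to XLA.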
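To debug retracing, one approach is to trace the function with the inputs from two different calls and compare the input types of the resulting jaxprs. The sketch below uses a hypothetical function scale_by. ClosedJaxpr.in_avals lists the abstract value (shape and dtype) of each traced argument.

```python
import jax
import jax.numpy as jnp


def scale_by(x, scale):
    # Hypothetical example function.
    return x * scale


# Two traces that differ only in the shape of x: under jit, this difference forces a retrace.
jaxpr_a = jax.make_jaxpr(scale_by)(jnp.ones(3), 2.0)
jaxpr_b = jax.make_jaxpr(scale_by)(jnp.ones(4), 2.0)
print(jaxpr_a.in_avals)
print(jaxpr_b.in_avals)

# Marking scale as static bakes its value into the jaxpr,
# so it no longer appears as a traced input.
static_jaxpr = jax.make_jaxpr(scale_by, static_argnums=(1,))(jnp.ones(3), 2.0)
print(static_jaxpr.in_avals)
```

A change in any entry of in_avals, or in a static value, is a change that jit keys its cache on.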
Worked mini-example
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x ** 2

# Inspect the jaxpr JAX sees: a single pjit call wrapping f's body
# (some newer JAX versions print this primitive as jit)
print(jax.make_jaxpr(f)(jnp.ones(3)))
# Actually run it
print(f(jnp.ones(3)))  # [1. 1. 1.]
```
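The printed jaxpr can also be explored programmatically. The sketch below relies on the equation attributes primitive and params. The parameter keys "jaxpr" and "name" reflect how current JAX versions represent a jit call. They are an internal detail and may change.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return x ** 2


closed = jax.make_jaxpr(f)(jnp.ones(3))

# The outer jaxpr holds exactly one equation: the call across the jit boundary.
call_eqn = closed.jaxpr.eqns[0]
print(call_eqn.primitive.name)  # "pjit" or "jit", depending on the JAX version
print(call_eqn.params["name"])  # name of the jitted function: f

# The traced body of f is stored on that equation as a ClosedJaxpr parameter.
body = call_eqn.params["jaxpr"]
print(body)
```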
Common pitfalls
- make_jaxpr does NOT cache: each call re-traces. Unlike jit, it never compiles anything, so there is no compilation cache at the make_jaxpr level.
- make_jaxpr shows the JAX IR, not XLA HLO. JAX primitives such as conv_general_dilated do appear in a jaxpr. What XLA does with them (the HLO ops they lower to, operator fusion) only shows up at the XLA level.
- jit caches compiled code keyed on argument shapes and dtypes (and static argument values). Changing the shape of x triggers a retrace and recompilation, visible as a new make_jaxpr output.
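The last two pitfalls can be checked directly. This is a sketch that uses a hypothetical function double. Its Python-side append runs only while JAX traces it, so the list gets one entry per trace. The convolution part shows that conv_general_dilated is visible as an ordinary primitive in a jaxpr. The shapes chosen are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

traced_shapes = []


def double(x):
    # Python side effect: executes only during tracing, not on cached calls.
    traced_shapes.append(x.shape)
    return x * 2.0


double_jit = jax.jit(double)
double_jit(jnp.ones(3))  # first call: trace + compile
double_jit(jnp.ones(3))  # same shape and dtype: cache hit, no trace
double_jit(jnp.ones(4))  # new shape: retrace + recompile
print(traced_shapes)  # [(3,), (4,)]

# conv_general_dilated is a JAX primitive, so it appears in the jaxpr.
lhs = jnp.ones((1, 1, 5))  # (batch, in_channels, width)
rhs = jnp.ones((1, 1, 3))  # (out_channels, in_channels, kernel_width)


def conv1d(a, b):
    return lax.conv_general_dilated(a, b, window_strides=(1,), padding="VALID")


conv_jaxpr = jax.make_jaxpr(conv1d)(lhs, rhs)
print(conv_jaxpr)
```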
Problem
Implement jaxpr_inside_jit(x) that:
- Defines inner(x) = jnp.sum(x ** 2) with @jax.jit.
- Returns inner(x), the jit-compiled result.
x: 1-D JAX array.
Returns: scalar, sum(x ** 2).
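Below is one possible sketch of a solution. It is not the site's reference solution. Here inner is defined once at module level, so jax.jit can reuse its cached trace and compiled code across calls. Redefining it inside jaxpr_inside_jit would create a fresh jitted function, and force a retrace, on every call.

```python
import jax
import jax.numpy as jnp


@jax.jit
def inner(x):
    # Sum of squares, compiled once per input shape/dtype.
    return jnp.sum(x ** 2)


def jaxpr_inside_jit(x):
    # Delegate to the jit-compiled function and return its scalar result.
    return inner(x)


x = jnp.array([1.0, 2.0, 3.0])
print(jaxpr_inside_jit(x))  # 14.0

# The outer jaxpr contains the jit boundary as a call equation named "inner".
print(jax.make_jaxpr(jaxpr_inside_jit)(x))
```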
Hints
jax
jaxpr
jit
ir