medium primitives

jaxpr with jit

Why this matters

jax.make_jaxpr and jax.jit operate at different abstraction layers of JAX’s compilation stack:

Python function
    ↓  make_jaxpr (tracing)
JAXPR  (JAX IR, you can inspect this)
    ↓  jit (compilation)
XLA HLO  (device IR, via jax.jit(f).lower(x).compiler_ir(dialect="hlo"); older releases: jax.xla_computation)
    ↓  XLA
Device binary
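
The stages in the diagram can be walked through by hand. The following is a minimal sketch, assuming a recent JAX release in which jitted functions expose .lower() and the lowered object exposes .as_text() and .compile(); the function f and the input x are illustrative and not part of the problem.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)  # [0., 1., 2.]

# Stage 1: trace the Python function into a jaxpr (JAX-level IR).
closed_jaxpr = jax.make_jaxpr(f)(x)
print(closed_jaxpr)

# Stage 2: lower the jitted function to compiler IR text
# (StableHLO by default in recent releases).
lowered = jax.jit(f).lower(x)
ir_text = lowered.as_text()
print(ir_text[:200])

# Stage 3: compile for the current backend and run the executable.
compiled = lowered.compile()
result = compiled(x)
print(result)
```

Note that jit performs its own trace internally; make_jaxpr only exposes that intermediate stage for inspection.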

Understanding this stack matters for:

  • Debugging unexpected retracing: comparing make_jaxpr output across calls shows what differs between traces (input shape, dtype, a static argument value, etc.).
  • Inspecting jit internals: make_jaxpr(jit(f))(x) reveals the structure JAX sees around the jit boundary.
  • Choosing the right tool: make_jaxpr for JAX-level IR; jax.jit(f).lower(x) (jax.xla_computation in older releases) for the lowered compiler IR.

Worked mini-example

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x ** 2

# Inspect the jaxpr JAX sees: a single pjit equation wrapping f's body
# (newer releases may print the primitive as jit)
print(jax.make_jaxpr(f)(jnp.ones(3)))

# Actually run it
print(f(jnp.ones(3)))  # [1. 1. 1.]
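
To look inside the jit boundary programmatically rather than by reading the printed text, the outer equation can be unpacked. This is a sketch under the assumption that the outer equation stores the body under the "jaxpr" parameter key, which is an internal detail that may change between releases; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x ** 2

closed = jax.make_jaxpr(f)(jnp.ones(3))

# The traced program consists of one equation: the call into the jitted f.
outer = closed.jaxpr.eqns[0]
print(outer.primitive.name)  # "pjit" or "jit", depending on the JAX version

# Assumption: the body of f is stored under the "jaxpr" parameter.
inner = outer.params["jaxpr"]
inner_jaxpr = getattr(inner, "jaxpr", inner)  # unwrap a ClosedJaxpr if present
inner_names = [eqn.primitive.name for eqn in inner_jaxpr.eqns]
print(inner_names)
```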

Common pitfalls

  • make_jaxpr does NOT give you jit's compilation cache: it only traces and returns the jaxpr, so do not rely on it to reuse compiled code the way jit does.
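  • make_jaxpr shows the JAX IR, not XLA HLO: you will see jax.lax primitives (e.g., conv_general_dilated, dot_general), but not XLA-level details such as operator fusion or memory layouts, which only appear after lowering and compilation.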
  • jit caches compiled code by input shape and dtype (plus static argument values): calling with a new shape of x triggers a retrace and recompilation, and make_jaxpr on the new input shows the changed shapes.
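
The caching behaviour of jit can be observed with a Python side effect, which runs only while JAX is tracing. The following is a small sketch; the function g and the counter trace_count are hypothetical names introduced for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def g(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return x * 2.0

g_jit = jax.jit(g)

g_jit(jnp.ones(3))              # first call: traces and compiles
g_jit(jnp.ones(3))              # same shape and dtype: cache hit, no trace
after_same_shape = trace_count  # 1

g_jit(jnp.ones(4))              # new shape: traces and compiles again
after_new_shape = trace_count   # 2

# The jaxprs for the two inputs differ in their input shapes.
shape_a = jax.make_jaxpr(g)(jnp.ones(3)).in_avals[0].shape
shape_b = jax.make_jaxpr(g)(jnp.ones(4)).in_avals[0].shape
print(after_same_shape, after_new_shape, shape_a, shape_b)
```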

Problem

Implement jaxpr_inside_jit(x) that:

  1. Defines inner(x) = jnp.sum(x ** 2) with @jax.jit.
  2. Returns inner(x), the jit-compiled result.
  • x: 1-D JAX array.

Returns: scalar β€” sum(x ** 2).
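
One possible shape for such a function, offered as a sketch that follows the two steps above and not as the reference solution:

```python
import jax
import jax.numpy as jnp

def jaxpr_inside_jit(x):
    # Step 1: define the jitted inner function.
    @jax.jit
    def inner(x):
        return jnp.sum(x ** 2)

    # Step 2: return the jit-compiled result.
    return inner(x)

out = jaxpr_inside_jit(jnp.array([1.0, 2.0, 3.0]))
print(out)  # 14.0
```

Because inner is redefined on every call, each call to jaxpr_inside_jit creates a new function object and jit cannot reuse its earlier trace. That is fine for this exercise, but in real code the jitted function is usually defined once at module level.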

Hints

jax jaxpr jit ir
