compilation_cache.set_cache_dir
Why this matters
By default, JAX keeps compiled programs only in an in-memory cache that lasts for the lifetime of the current Python process, so every new process recompiles everything from scratch. For production services that restart often but run the same JIT-compiled functions on every start (e.g., a model inference server), this repeated compilation adds startup latency.
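JAX also provides a persistent compilation cache: XLA executables compiled during one process run are saved to disk and reloaded on later runs, which skips the XLA compilation step. Tracing and lowering still happen on every run, because the cache key is computed from the lowered program.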
import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# OR (older API):
# from jax.experimental.compilation_cache import compilation_cache
# compilation_cache.set_cache_dir("/tmp/jax_cache")
After that call, JIT-compiled functions whose cache key matches a stored entry load the pre-compiled binary instead of recompiling.
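To make the skipped step concrete, here is a minimal sketch using JAX's ahead-of-time API, which splits `jax.jit` into an explicit lowering stage and an explicit compilation stage. It assumes a recent JAX release; `f` and `x` are illustrative names. A persistent-cache hit replaces only the `compile()` stage.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x ** 2 + 1)


x = jnp.array([1.0, 2.0, 3.0])

# Stage 1: trace the Python function and lower it to an XLA program.
# This runs in every process; the persistent-cache key is derived from it.
lowered = jax.jit(f).lower(x)

# Stage 2: XLA compilation of the lowered program. This is the step
# that a persistent-cache hit replaces with a load from disk.
compiled = lowered.compile()

result = compiled(x)
print(result)  # 17.0
```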
Worked mini-example
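import jax, jax.numpy as jnp
# Enable the persistent cache (before the first JIT compilation):
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# By default only programs that take >= 1 s to compile are persisted;
# lower the thresholds so this tiny function is cached too:
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
@jax.jit
def f(x):
    return jnp.sum(x ** 2 + 1)
# First run of the script: compiles f and saves the executable to /tmp/jax_cache
# Later runs of the same script: load the executable from the cache
print(f(jnp.array([1.0, 2.0, 3.0])))  # 17.0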
Common pitfalls
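- No in-process speedup: the persistent cache only helps across process boundaries. Within a single process, JAX's in-memory cache already reuses compiled functions, so repeated calls with the same shapes and dtypes are not recompiled.
- Configure before the first compilation: set the cache directory before any JIT compilation runs, because the cache is initialized on first use and later changes to the directory may not take effect.
- Cache invalidation: the cache key covers the lowered program, the compile options, the jax/jaxlib versions, and the backend and device configuration. After a JAX upgrade, old entries no longer match, so the first run recompiles.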
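The first pitfall can be checked directly. A minimal sketch, assuming a hypothetical module-level counter `trace_count` that the function increments as a Python side effect; such side effects run only while JAX traces, so the counter shows when the in-memory cache is hit and when a new input shape forces a retrace:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented only while JAX traces g


@jax.jit
def g(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jnp.sum(x ** 2 + 1)


g(jnp.ones(3))   # first call: trace + compile
g(jnp.zeros(3))  # same shape/dtype: served by the in-memory cache
print(trace_count)  # 1
g(jnp.ones(4))   # new shape: retrace and recompile
print(trace_count)  # 2
```

No persistent cache is involved here; the reuse comes entirely from the per-process cache keyed on input shapes and dtypes.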
Problem
Implement jit_with_warm_cache(x) that returns sum(x ** 2 + 1) using a
top-level @jax.jit-decorated helper. The lesson is conceptual: the persistent
cache is enabled globally via jax_compilation_cache_dir, so the function itself
only needs to demonstrate correct jit usage.
- x: 1-D JAX array.
- Returns: a scalar, jnp.sum(x ** 2 + 1).
Examples (not from the test set):
- jit_with_warm_cache(jnp.array([1.0, 2.0, 3.0])) → 17.0
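One possible solution, offered as a sketch rather than the lesson's reference answer; the helper name `_sum_sq_plus_one` is hypothetical:

```python
import jax
import jax.numpy as jnp


@jax.jit
def _sum_sq_plus_one(x):
    # Top-level jitted helper: traced once per input shape/dtype, then reused.
    return jnp.sum(x ** 2 + 1)


def jit_with_warm_cache(x):
    # Persistent caching is configured globally (jax_compilation_cache_dir),
    # so nothing cache-specific is needed inside the function.
    return _sum_sq_plus_one(x)


print(jit_with_warm_cache(jnp.array([1.0, 2.0, 3.0])))  # 17.0
```

Decorating the helper at module level means every call reuses one `jax.jit` wrapper and its in-memory cache; wrapping a freshly defined function inside `jit_with_warm_cache` would create a new function object, and therefore retrace, on every call.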