Compilation Cache Key
Why this matters
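JAX's JIT compilation cache is keyed on function identity plus the
shapes and dtypes of the array arguments (and the values of any arguments
marked static). Closed-over Python values are not part of the key: they are
baked into the compiled code as constants when the function is first traced.
When you define a @jax.jit-decorated function at module top level, its
identity is stable across calls and the cache works as expected. When you
define the same function inside another function, a brand-new function object
and jit wrapper are created on every call. That guarantees a cache miss, so
JAX retraces and recompiles each time.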
import jax
import jax.numpy as jnp

# GOOD: stable cache key (one module-level jitted function)
@jax.jit
def _stable_sum_sq(x):
    return jnp.sum(x ** 2)

# BAD: a new lambda and jit object on every call, so it always recompiles
def hot_call(x):
    f = jax.jit(lambda y: jnp.sum(y ** 2))
    return f(x)
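A quick way to see the difference is to count traces with a Python side effect. Python code inside a jitted function runs only while JAX traces it, not on cached calls. The following is a minimal sketch. The names `trace_count`, `stable_sum_sq`, and `body` are hypothetical, and `body` uses `def` instead of a lambda only so it can bump the counter.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: incremented only when JAX traces a function.
trace_count = {"stable": 0, "bad": 0}

@jax.jit
def stable_sum_sq(x):
    trace_count["stable"] += 1  # Python side effect: runs at trace time only
    return jnp.sum(x ** 2)

def hot_call(x):
    # A new function object is created on every call, so jax.jit
    # cannot reuse the previous compilation.
    def body(y):
        trace_count["bad"] += 1
        return jnp.sum(y ** 2)
    return jax.jit(body)(x)

x = jnp.array([1.0, 2.0, 3.0])
for _ in range(3):
    stable_sum_sq(x)
    hot_call(x)

print(trace_count)  # {'stable': 1, 'bad': 3}
```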
Worked mini-example
import jax
import jax.numpy as jnp

@jax.jit
def sum_sq(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
print(sum_sq(x))  # 14.0 (first call: trace and compile)
print(sum_sq(x))  # 14.0 (cache hit, no recompile)
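Because the cache is also keyed on argument shapes and dtypes, a stable function identity still retraces when a new shape or dtype arrives. The sketch below records each trace in a hypothetical `traces` list. It assumes JAX's default 32-bit mode, in which integer literals become int32 and floats become float32.

```python
import jax
import jax.numpy as jnp

traces = []  # hypothetical log: one entry per trace

@jax.jit
def sum_sq(x):
    traces.append(x.shape)  # runs only when a new specialization is traced
    return jnp.sum(x ** 2)

sum_sq(jnp.array([1.0, 2.0, 3.0]))  # shape (3,), float: traced
sum_sq(jnp.array([4.0, 5.0, 6.0]))  # same shape and dtype: cache hit
sum_sq(jnp.array([1.0, 2.0]))       # new shape (2,): retraced
sum_sq(jnp.array([1, 2, 3]))        # integer dtype: retraced
print(len(traces))  # 3
```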
Common pitfalls
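- Lambdas inside hot loops: for _ in range(N): f = jax.jit(lambda x: ...); f(x) traces and compiles N times. Hoist the jitted function out of the loop.
- Closed-over Python values: values read from an enclosing scope or a global are baked into the compiled code as constants at trace time. They are not part of the cache key. If such a value changes, the already-compiled function silently keeps the old one, and rebuilding the closure to pick up the new value forces a recompile. Pass the value as an argument marked with static_argnums instead, so each distinct value gets its own cache entry.
- functools.partial: each partial call creates a new object, so calling jax.jit(partial(...)) in a hot path produces a new cache key every time. Create the partial, and jit it, once at module level.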
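The last two pitfalls can be reproduced in a few lines. This is a minimal sketch. `scale`, `scaled_closure`, `scaled_static`, `power_sum`, and `sum_sq_once` are hypothetical names, and a module-level global stands in for any closed-over Python value.

```python
import functools

import jax
import jax.numpy as jnp

scale = 2.0  # hypothetical Python value read by the jitted function

@jax.jit
def scaled_closure(x):
    # `scale` is read once, while tracing, and baked in as a constant.
    return x * scale

before = scaled_closure(1.0)  # traced with scale == 2.0
scale = 3.0
after = scaled_closure(1.0)   # cache hit: still uses 2.0, no retrace
print(before, after)          # 2.0 2.0

# Fix: make the value a static argument so it is part of the cache key.
@functools.partial(jax.jit, static_argnums=1)
def scaled_static(x, s):
    return x * s

print(scaled_static(1.0, 2.0), scaled_static(1.0, 3.0))  # 2.0 3.0

# functools.partial: build and jit the partial once, outside the hot path.
def power_sum(x, p):
    return jnp.sum(x ** p)

sum_sq_once = jax.jit(functools.partial(power_sum, p=2))
x = jnp.array([1.0, 2.0, 3.0])
for _ in range(3):
    total = sum_sq_once(x)  # same jitted object every iteration
print(total)  # 14.0
```

Static arguments must be hashable, which is why a Python float works here but a JAX array would not.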
Problem
Implement jit_cached_call(x) that returns sum(x ** 2) using a
top-level @jax.jit-decorated helper so the cache key is stable.
- x: 1-D JAX array.
- Returns: a scalar equal to jnp.sum(x ** 2).
Examples (not from the test set):
- jit_cached_call(jnp.array([3.0, 4.0])) → 25.0
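One possible sketch, assuming nothing beyond the specification above, is a thin wrapper around a module-level jitted helper. The helper name `_sum_sq` is hypothetical. Because the wrapper always calls the same jitted object, repeated calls with the same shape and dtype reuse one compilation.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper, defined once at module level so its identity is stable.
@jax.jit
def _sum_sq(x):
    return jnp.sum(x ** 2)

def jit_cached_call(x):
    # Delegates to the same jitted object on every call.
    return _sum_sq(x)

print(jit_cached_call(jnp.array([3.0, 4.0])))  # 25.0
```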