
Compilation Cache Key

Why this matters

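JAX's JIT compilation cache is keyed on the function's identity plus the abstract signature of its inputs (array shapes and dtypes) and the values of any static arguments. Closed-over Python values are not part of the key: they are read once, at trace time, and baked into the compiled code as constants. When you define a @jax.jit-decorated function at module top level, its identity is stable across calls and the cache works as expected. When you define the same function inside another function, every call creates a brand-new function object and jit wrapper, which guarantees a cache miss and a full retrace and recompile.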

import jax, jax.numpy as jnp

# GOOD: stable cache key (one jit object, created once at import time)
@jax.jit
def _stable_sum_sq(x):
    return jnp.sum(x ** 2)

# BAD: new function object and jit wrapper on every call → always recompiles
def hot_call(x):
    f = jax.jit(lambda y: jnp.sum(y ** 2))
    return f(x)
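To see the difference directly, the sketch below counts traces with a Python side effect. Such a side effect runs only while JAX traces a function (which happens on a cache miss, before compilation), never on a cached call. The names trace_count, stable_sum_sq and body are illustrative, not part of any JAX API.

```python
import jax
import jax.numpy as jnp

# Counts how many times each function is traced (hypothetical instrumentation).
trace_count = {"stable": 0, "fresh": 0}

@jax.jit
def stable_sum_sq(x):
    trace_count["stable"] += 1  # Python side effect: runs only during tracing
    return jnp.sum(x ** 2)

def hot_call(x):
    # A new function object, and a new jit wrapper, on every call.
    def body(y):
        trace_count["fresh"] += 1  # runs only during tracing
        return jnp.sum(y ** 2)
    return jax.jit(body)(x)

x = jnp.array([1.0, 2.0, 3.0])
for _ in range(3):
    stable_sum_sq(x)
    hot_call(x)

print(trace_count)  # {'stable': 1, 'fresh': 3}
```

Three calls to the top-level function trace once; three calls to hot_call trace three times, because each call hands jax.jit a function it has never seen.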

Worked mini-example

import jax, jax.numpy as jnp

@jax.jit
def sum_sq(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
print(sum_sq(x))  # 14.0: traced and compiled on the first call
print(sum_sq(x))  # 14.0: cache hit, no recompile
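The same counting trick shows which parts of the input enter the cache key: new values with the same shape and dtype hit the cache, while a new shape or a new dtype triggers a retrace. The traces list is hypothetical instrumentation, and the dtype comments assume JAX's default 32-bit mode (with 64-bit mode enabled the dtypes change, but the trace count does not).

```python
import jax
import jax.numpy as jnp

traces = []  # records the abstract signature seen at each trace

@jax.jit
def sum_sq(x):
    traces.append((x.shape, x.dtype))  # runs only while tracing
    return jnp.sum(x ** 2)

sum_sq(jnp.array([1.0, 2.0, 3.0]))  # shape (3,), float32: first trace
sum_sq(jnp.array([4.0, 5.0, 6.0]))  # same shape and dtype: cache hit
sum_sq(jnp.array([1.0, 2.0]))       # new shape (2,): retrace
sum_sq(jnp.array([1, 2, 3]))        # new dtype (int32): retrace

print(len(traces))  # 3
```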

Common pitfalls

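  • Lambdas inside hot loops: for _ in range(N): f = jax.jit(lambda y: ...) creates N distinct functions and recompiles N times. Hoist the jitted function out of the loop.
  • Closed-over Python values: a closed-over Python int or float is read once, when the function is traced, and baked into the compiled code. If the value later changes, a stable jitted function silently keeps using the old value; rebuilding the closure to pick up the new value creates a new function identity and recompiles. Pass the value as an argument instead: as a regular (traced) argument if it only feeds arithmetic, or via static_argnums if it controls shapes or Python control flow (each distinct static value then gets its own cache entry).
  • functools.partial: each call to functools.partial creates a new object, and therefore a new cache key if you jit it afresh every time. Build the partial (and its jitted wrapper) once, at module level.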
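A minimal sketch of the closed-over-value pitfall and two ways around it. The names scale, scaled_sum_closure, scaled_sum_static and scaled_sum_traced are illustrative; the expected outputs assume default float32 arrays.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])

# Pitfall: a closed-over Python value is read once, at trace time.
scale = 2.0

@jax.jit
def scaled_sum_closure(x):
    return scale * jnp.sum(x)

first = scaled_sum_closure(x)   # traces with scale == 2.0
scale = 10.0
stale = scaled_sum_closure(x)   # cache hit: still uses the baked-in 2.0

# Fix A: static argument. Each distinct value of s gets its own cache entry.
@partial(jax.jit, static_argnums=1)
def scaled_sum_static(x, s):
    return s * jnp.sum(x)

# Fix B: traced argument. One compiled executable serves every value of s.
traces = []

@jax.jit
def scaled_sum_traced(x, s):
    traces.append(1)  # Python side effect: runs only while tracing
    return s * jnp.sum(x)

print(first, stale)                                           # 6.0 6.0
print(scaled_sum_static(x, 2.0), scaled_sum_static(x, 10.0))  # 6.0 30.0
print(scaled_sum_traced(x, 2.0), scaled_sum_traced(x, 10.0))  # 6.0 30.0
print(len(traces))                                            # 1
```

The traced-argument version is usually the better choice when the value only feeds arithmetic, since one compilation covers every value; static_argnums is for values that must be known at trace time.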

Problem

Implement jit_cached_call(x) that returns sum(x ** 2) using a top-level @jax.jit-decorated helper so the cache key is stable.

  • x: 1-D JAX array.

Returns: a scalar equal to jnp.sum(x ** 2).

Examples (not from the test set):

  • jit_cached_call(jnp.array([3.0, 4.0])) → 25.0
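One possible solution sketch, relying only on what the problem statement specifies; the helper name _sum_sq is an illustrative choice.

```python
import jax
import jax.numpy as jnp

@jax.jit
def _sum_sq(x):
    # Defined once at module level, so the jit cache sees one stable function.
    return jnp.sum(x ** 2)

def jit_cached_call(x):
    """Return jnp.sum(x ** 2) via the module-level jitted helper."""
    return _sum_sq(x)

print(jit_cached_call(jnp.array([3.0, 4.0])))  # 25.0
```

Because jit_cached_call itself is not jitted and only forwards to _sum_sq, repeated calls with same-shaped, same-dtype inputs reuse a single compiled executable.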

