jit Retrace on Shape Change
Why this matters
jax.jit compiles a function once per unique (shape, dtype) combination
of its inputs. The compiled XLA program is cached; subsequent calls with the
same shapes hit the cache and pay only dispatch overhead (~microseconds).
When you call the same jitted function with a different shape, JAX must retrace and recompile, paying the full XLA compilation cost again (tens of milliseconds or more). This is called a retrace.
import jax
import jax.numpy as jnp

f = jax.jit(jnp.sum)
f(jnp.ones(3))  # compile for shape [3]: ~50 ms (illustrative)
f(jnp.ones(3))  # cache hit: ~5 µs
f(jnp.ones(4))  # retrace for shape [4]: ~50 ms again
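One way to see retracing directly, as a sketch: Python side effects inside a jitted function run only while JAX traces it, so a plain global counter increments once per trace, not once per call. The names `trace_count` and `counted_sum` are illustrative helpers, not part of any JAX API.

```python
import jax
import jax.numpy as jnp

# Illustrative counter: Python code in the function body runs only during
# tracing, so this counts traces rather than calls.
trace_count = 0

def counted_sum(x):
    global trace_count
    trace_count += 1  # executes at trace time only
    return jnp.sum(x)

f = jax.jit(counted_sum)

f(jnp.ones(3))  # first call with shape [3]: traced and compiled
f(jnp.ones(3))  # same shape and dtype: cache hit, body not re-run
f(jnp.ones(4))  # new shape [4]: retraced and recompiled

print(trace_count)  # 2
```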
In production this matters when:
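- Data loaders yield variable-length sequences (padding to one fixed length, or bucketing into a few fixed lengths, keeps the number of retraces small).
- You mix Python scalars and 0-D arrays as arguments (a Python scalar is traced as a weakly typed value, so its abstract signature differs from a 0-D array's and JAX compiles a separate version).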
- A grid-search changes array sizes between iterations.
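For the data-loader case, bucketing can be sketched as follows, assuming sequences are 1-D float arrays and that three bucket lengths (8, 16, 32, chosen arbitrarily here) cover the data. Each sequence is zero-padded to the smallest bucket that fits, and a mask marks the real elements; for a plain sum the zero padding alone would suffice, but the mask keeps the pattern correct for reductions such as a mean. `BUCKETS`, `pad_to_bucket`, and `masked_sum` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np

BUCKETS = (8, 16, 32)  # hypothetical bucket lengths

trace_count = 0  # counts traces (see the counter trick above)

@jax.jit
def masked_sum(x, mask):
    global trace_count
    trace_count += 1  # runs only when JAX traces a new shape
    return jnp.sum(x * mask)

def pad_to_bucket(seq):
    """Zero-pad seq to the smallest bucket that fits and build a 0/1 mask."""
    n = len(seq)
    size = next((b for b in BUCKETS if b >= n), None)
    if size is None:
        raise ValueError(f"sequence of length {n} exceeds the largest bucket")
    x = np.zeros(size, dtype=np.float32)
    mask = np.zeros(size, dtype=np.float32)
    x[:n] = seq
    mask[:n] = 1.0
    return jnp.asarray(x), jnp.asarray(mask)

lengths = [3, 5, 7, 9, 12, 15, 20]  # 7 distinct lengths
totals = []
for n in lengths:
    x, mask = pad_to_bucket(np.arange(n, dtype=np.float32))
    totals.append(float(masked_sum(x, mask)))

# One compile per bucket actually used (8, 16, 32), not one per length.
print(trace_count, totals)
```

Without bucketing, the same loop would compile seven times; the trade-off is some wasted compute on padding.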
Worked mini-example
import jax
import jax.numpy as jnp
f = jax.jit(jnp.sum)
a = jnp.array([1.0, 2.0, 3.0]) # shape [3]
b = jnp.array([4.0, 3.0, 2.0, 1.0]) # shape [4]
result = f(a) + f(b) # f compiled twice (different shapes)
# result == 6.0 + 10.0 == 16.0
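To check the compile-versus-cache-hit gap on your own machine, a rough timing sketch along these lines may help. It uses `time.perf_counter` and calls `block_until_ready()` on the result because JAX dispatches work asynchronously; without that, the timer could stop before the computation finishes. Absolute numbers depend on hardware, backend, and JAX version, and `timed_call` is an illustrative helper.

```python
import time

import jax
import jax.numpy as jnp

f = jax.jit(jnp.sum)
a = jnp.ones(3)
b = jnp.ones(4)

def timed_call(fn, x):
    """Wall-clock seconds for one call, including any trace/compile it triggers."""
    start = time.perf_counter()
    fn(x).block_until_ready()  # wait for the asynchronous result
    return time.perf_counter() - start

t_first = timed_call(f, a)  # shape [3], first call: trace + compile + run
t_hit = timed_call(f, a)    # shape [3] again: cache hit, dispatch + run
t_new = timed_call(f, b)    # shape [4]: retrace + compile + run

print(f"first call: {t_first * 1e3:.2f} ms")
print(f"cache hit:  {t_hit * 1e6:.1f} us")
print(f"new shape:  {t_new * 1e3:.2f} ms")
```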
Common pitfalls
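- Compilation overhead dominates: calling a jitted function with many distinct shapes negates its performance benefits, because you pay the compilation cost for every new shape.
- Same shape means a single compile: if both arrays have the same shape (and dtype), JAX compiles only once; this is the common case.
- Dtype also matters: float32 and float64 arrays of the same shape trigger separate compilations. (float64 is only available when jax_enable_x64 is enabled; otherwise float64 inputs are converted to float32. With default settings, int32 versus float32 shows the same effect.)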
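The dtype pitfall can be observed with the same trace-counter trick, sketched here with int32 versus float32 so it works under default settings (no x64 flag needed). `trace_count` and `counted_sum` are illustrative names, not JAX APIs.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def counted_sum(x):
    global trace_count
    trace_count += 1  # runs only when JAX (re)traces
    return jnp.sum(x)

f = jax.jit(counted_sum)

f(jnp.ones(3, dtype=jnp.float32))  # trace 1: float32, shape [3]
f(jnp.ones(3, dtype=jnp.int32))    # trace 2: same shape, different dtype
f(jnp.ones(3, dtype=jnp.float32))  # cache hit: float32 [3] already compiled

print(trace_count)  # 2
```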
Problem
Implement cache_independent_count(x_a, x_b) that:
- Creates a single jitted version of jnp.sum.
- Calls it on both x_a and x_b (which may have different shapes).
- Returns the sum of the two results as a scalar.
Arguments:
- x_a: 1-D JAX array.
- x_b: 1-D JAX array (may differ in shape from x_a).
Returns: a scalar equal to jnp.sum(x_a) + jnp.sum(x_b).
Example (not from the test set):
- cache_independent_count(jnp.array([1.0, 2.0]), jnp.array([3.0])) returns 6.0.
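A minimal sketch of one possible solution, assuming the checker compares only the returned value (the reference solution is not shown here). The local name `summed` is illustrative. A single jitted callable wraps `jnp.sum`; if `x_a` and `x_b` differ in shape, the second call retraces behind that same callable, and if they match, it reuses the first compilation.

```python
import jax
import jax.numpy as jnp

def cache_independent_count(x_a, x_b):
    # One jitted wrapper; JAX keeps one compiled executable per
    # distinct input shape/dtype behind this single callable.
    summed = jax.jit(jnp.sum)
    # Different shapes: the second call compiles again. Same shapes: cache hit.
    return summed(x_a) + summed(x_b)

result = cache_independent_count(jnp.array([1.0, 2.0]), jnp.array([3.0]))
print(result)  # 6.0
```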