
jit Retrace on Shape Change

Why this matters

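jax.jit compiles a function once per unique combination of input shapes and dtypes. More precisely, it compiles once per abstract signature of its arguments, which also includes the pytree structure and any static arguments. The compiled XLA program is cached, so later calls with the same shapes and dtypes hit the cache and pay only dispatch overhead (on the order of microseconds).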

When you call the same jitted function with a different shape, JAX must retrace the Python function and recompile it, paying the full XLA compilation cost again (tens of milliseconds or more). This is called a retrace.

import jax
import jax.numpy as jnp

f = jax.jit(jnp.sum)

# Timings are rough and hardware-dependent.
f(jnp.ones(3))   # traces and compiles for shape (3,): ~50 ms
f(jnp.ones(3))   # cache hit: ~5 µs
f(jnp.ones(4))   # retraces and recompiles for shape (4,): ~50 ms again
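Wall-clock timings are noisy. A more direct way to observe retraces is a Python side effect inside the jitted function: Python code in the body runs only while JAX traces it, not on cache hits. This is a minimal sketch; the `counter` dictionary is a hypothetical helper for illustration, not a JAX API.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: counts how many times JAX traces the function body.
counter = {"traces": 0}

@jax.jit
def counted_sum(x):
    # Python side effects run only during tracing, never on a cache hit.
    counter["traces"] += 1
    return jnp.sum(x)

counted_sum(jnp.ones(3))  # first call with shape (3,): traces
counted_sum(jnp.ones(3))  # same shape and dtype: cache hit, no trace
counted_sum(jnp.ones(4))  # new shape (4,): retraces

print(counter["traces"])  # 2
```

Alternatively, enabling the `jax_log_compiles` option, for example with `jax.config.update("jax_log_compiles", True)`, makes JAX log each compilation.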

In production this matters when:

  • Data loaders yield varying-length sequences (padding or bucketing avoids retraces).
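  • You mix Python scalars with 0-D arrays, or int with float values. Python scalars are traced as weakly typed values, and a different dtype or weak type gives a different signature, so each variant is traced separately.
  • A grid search changes array sizes between iterations.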
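Padding to a small set of bucket lengths is one way to bound how many distinct shapes a data loader produces. A minimal sketch, assuming 1-D inputs no longer than the largest bucket; `BUCKETS`, `pad_to_bucket`, and `counter` are hypothetical names chosen for illustration. Zero padding leaves a sum unchanged; reductions such as a mean or max would also need a mask or the true length.

```python
import jax
import jax.numpy as jnp

BUCKETS = (8, 16, 32)  # hypothetical bucket lengths
counter = {"traces": 0}

def pad_to_bucket(x, buckets=BUCKETS):
    """Zero-pad a 1-D array up to the smallest bucket that fits it."""
    n = x.shape[0]
    size = next((b for b in buckets if b >= n), None)
    if size is None:
        raise ValueError(f"length {n} exceeds largest bucket {buckets[-1]}")
    return jnp.pad(x, (0, size - n))  # appended zeros do not change the sum

@jax.jit
def padded_sum(x):
    counter["traces"] += 1  # runs only when JAX (re)traces
    return jnp.sum(x)

lengths = [3, 5, 7, 10, 12]  # five distinct raw lengths
totals = [padded_sum(pad_to_bucket(jnp.ones(n))) for n in lengths]

print([float(t) for t in totals])  # [3.0, 5.0, 7.0, 10.0, 12.0]
print(counter["traces"])           # 2 (one per bucket used: 8 and 16)
```

Five raw lengths collapse onto two padded shapes, so the jitted function compiles twice instead of five times. Larger buckets mean fewer compiles but more wasted work on padding.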

Worked mini-example

import jax
import jax.numpy as jnp

f = jax.jit(jnp.sum)

a = jnp.array([1.0, 2.0, 3.0])        # shape (3,)
b = jnp.array([4.0, 3.0, 2.0, 1.0])   # shape (4,)

result = f(a) + f(b)   # f compiled twice (different shapes)
# result == 6.0 + 10.0 == 16.0

Common pitfalls

  • Compilation overhead dominates: calling a jitted function with many distinct shapes negates its performance benefit, because every new shape pays the full compilation cost.
  • Same shape → single compile: if both arrays have the same shape (and dtype), JAX compiles only once; this is the common case.
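  • Dtype also matters: float32 and int32 arrays of the same shape trigger separate compilations. The same holds for float32 versus float64, but only when jax_enable_x64 is on; by default JAX converts float64 inputs to float32.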
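The last two pitfalls can be checked with the same trace-counting trick: new values with an unchanged shape and dtype reuse the cached program, while a dtype change forces a retrace. Casting inputs to one dtype up front, here with `jnp.asarray(..., dtype=jnp.float32)`, keeps every call on a single signature. A minimal sketch; `counter` is a hypothetical helper, and explicit dtypes are used throughout so weak typing plays no role.

```python
import jax
import jax.numpy as jnp

counter = {"traces": 0}  # hypothetical trace counter

@jax.jit
def counted_sum(x):
    counter["traces"] += 1  # executes only during tracing
    return jnp.sum(x)

counted_sum(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))  # float32 (3,): trace 1
counted_sum(jnp.array([7.0, 8.0, 9.0], dtype=jnp.float32))  # same signature: cache hit
counted_sum(jnp.array([1, 2, 3], dtype=jnp.int32))          # int32 (3,): trace 2
print(counter["traces"])  # 2

# Casting to one dtype before the call keeps mixed inputs on one compiled program.
for raw in ([1, 2, 3], [1.5, 2.5, 3.5]):
    counted_sum(jnp.asarray(raw, dtype=jnp.float32))  # float32 (3,): cache hit
print(counter["traces"])  # still 2
```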

Problem

Implement cache_independent_count(x_a, x_b) that:

  1. Creates a single jitted version of jnp.sum.
  2. Calls it on both x_a and x_b (which may have different shapes).
  3. Returns the sum of the two results as a scalar.

Arguments:

  • x_a: 1-D JAX array.
  • x_b: 1-D JAX array (may differ in shape from x_a).

Returns: a scalar equal to jnp.sum(x_a) + jnp.sum(x_b).

Example (not from the test set):

  • cache_independent_count(jnp.array([1.0, 2.0]), jnp.array([3.0])) returns 6.0.
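One possible implementation follows the three steps directly: build one jitted `jnp.sum`, apply it to both inputs, and add the results. This is a sketch, not the site's reference solution. Because `x_a` and `x_b` may differ in shape, the single jitted callable may be traced and compiled once per distinct shape; the result is the same either way.

```python
import jax
import jax.numpy as jnp

def cache_independent_count(x_a, x_b):
    # Step 1: one jitted version of jnp.sum, shared by both calls.
    jitted_sum = jax.jit(jnp.sum)
    # Step 2: call it on each input; a new shape triggers a retrace.
    sum_a = jitted_sum(x_a)
    sum_b = jitted_sum(x_b)
    # Step 3: add the two 0-D results into a single scalar.
    return sum_a + sum_b

result = cache_independent_count(jnp.array([1.0, 2.0]), jnp.array([3.0]))
print(float(result))  # 6.0
```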

Hints

jax jit retracing
