Difficulty: hard | Topic: primitives
AOT via jit.lower().compile()
Why this matters
JAX normally compiles functions just-in-time (JIT): the first call to a jitted function triggers XLA compilation for the observed argument shapes and dtypes. In production you may want to pay that compilation cost ahead of time, at build or startup time, rather than on the first inference request. The AOT (ahead-of-time) pattern lets you do exactly that:
compiled = (
    jax.jit(f)
    .lower(*example_args)  # trace + lower to StableHLO
    .compile()             # compile to a device executable
)
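The returned object is a compiled callable (a jax.stages.Compiled). Calling it at runtime skips tracing and compilation entirely; all that remains is a check that the arguments match the compiled shapes and dtypes, followed by execution of the compiled binary.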
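The two stages can be observed directly. In the sketch below, the function f and the traced_shapes list are illustrative only. The list records a Python side effect that runs only while JAX traces f. This shows that tracing happens once, inside lower(), and that calling the compiled object does not re-run the Python body.

```python
import jax
import jax.numpy as jnp

traced_shapes = []  # illustrative log; appended to only while JAX traces f

def f(x):
    traced_shapes.append(x.shape)  # Python side effect: runs at trace time only
    return jnp.sum(x * 2.0)

x = jnp.ones(3, dtype=jnp.float32)

lowered = jax.jit(f).lower(x)  # stage 1: trace f and lower it to StableHLO (a jax.stages.Lowered)
hlo_text = lowered.as_text()   # human-readable StableHLO module, handy for inspection

compiled = lowered.compile()   # stage 2: XLA compiles the module (a jax.stages.Compiled)

for _ in range(3):
    result = compiled(x)       # runs the compiled binary; f's Python body is not re-run

print(traced_shapes)  # [(3,)]: f was traced exactly once, inside lower()
print(result)         # 6.0, the sum of 2.0 * ones(3)
```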
This matters for:
- Latency-sensitive serving: the first request is as fast as subsequent ones, because no compilation happens on the request path.
- Build-time validation: catch shape/dtype mismatches before deployment.
- Reproducible binaries: pin the compiled artifact to a specific shape.
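Build-time validation does not require real data, because lower() also accepts jax.ShapeDtypeStruct placeholders. In the sketch below, the function predict and its shapes are hypothetical. A spec whose shapes do not fit the function fails while tracing, before anything is deployed.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Hypothetical model: one matrix-vector product followed by tanh.
    return jnp.tanh(w @ x)

# Describe inputs by shape and dtype only; no concrete arrays are needed to compile.
w_spec = jax.ShapeDtypeStruct((8, 16), jnp.float32)
x_spec = jax.ShapeDtypeStruct((16,), jnp.float32)
compiled = jax.jit(predict).lower(w_spec, x_spec).compile()

# A mismatched spec is rejected during tracing, i.e. at "build time".
bad_x_spec = jax.ShapeDtypeStruct((15,), jnp.float32)
try:
    jax.jit(predict).lower(w_spec, bad_x_spec)
    build_ok = True
except (TypeError, ValueError) as err:  # exact exception type caught broadly here
    build_ok = False
    print("rejected at build time:", type(err).__name__)

# At "inference time", real arrays with the declared shapes run the compiled binary.
out = compiled(jnp.zeros((8, 16), jnp.float32), jnp.zeros(16, jnp.float32))
print(out.shape)  # (8,)
```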
Worked mini-example
import jax
import jax.numpy as jnp
x = jnp.ones(4)
f = jax.jit(jnp.sum)
# Compile once at "build time"
compiled = f.lower(x).compile()
# Fast path at "inference time"
result = compiled(x) # → 4.0, no recompilation
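Continuing the mini-example: the compiled object accepts only the shape and dtype it was built for, and jax.jit's automatic recompilation does not apply to it. The sketch catches the error broadly because the exact exception type is an implementation detail it does not rely on.

```python
import jax
import jax.numpy as jnp

f = jax.jit(jnp.sum)
compiled_4 = f.lower(jnp.ones(4, dtype=jnp.float32)).compile()  # bound to float32[4]

try:
    compiled_4(jnp.ones(5, dtype=jnp.float32))  # float32[5] does not match the compiled signature
    rejected = False
except Exception as err:  # caught broadly; the exact type may vary across JAX releases
    rejected = True
    print("rejected:", type(err).__name__)

# jax.jit would recompile silently; the AOT path needs an explicit call per shape.
compiled_5 = f.lower(jnp.ones(5, dtype=jnp.float32)).compile()
print(compiled_5(jnp.ones(5, dtype=jnp.float32)))  # 5.0
```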
Common pitfalls
- Shape binding: the compiled function is bound to the shapes and dtypes of the example arguments. Calling it with a different shape or dtype raises an error instead of triggering recompilation.
- Shape-polymorphic AOT: if one artifact must handle multiple shapes, use jax.export with jax.ShapeDtypeStruct and symbolic dimensions (jax.export.symbolic_shape). The exported artifact is serialized StableHLO, so it is still compiled for each concrete shape when called.
- lower() vs compile(): lower() alone returns a jax.stages.Lowered object holding the StableHLO module (viewable with .as_text()); compile() turns it into an executable. AOT needs both steps.
- Not re-jit-able: the compiled object does not cache additional shapes the way jax.jit does. Each shape needs its own lower().compile() call.
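Here is a minimal shape-polymorphic sketch. It assumes a recent JAX release that includes the public jax.export module, and the symbolic dimension name "n" is arbitrary. The serialized artifact is StableHLO, so each concrete shape is still compiled when restored.call runs. What you gain is one portable artifact instead of one per shape.

```python
import jax
import jax.numpy as jnp
from jax import export

# Declare a symbolic dimension "n" in place of a concrete length.
(n,) = export.symbolic_shape("n")
spec = jax.ShapeDtypeStruct((n,), jnp.float32)

# Export once: the result holds shape-polymorphic StableHLO, not a device binary.
exported = export.export(jax.jit(jnp.sum))(spec)
blob = exported.serialize()  # a bytearray you could write to disk at build time

# Later (possibly in another process): deserialize and call with concrete shapes.
restored = export.deserialize(blob)
print(restored.call(jnp.ones(4, dtype=jnp.float32)))  # 4.0
print(restored.call(jnp.ones(7, dtype=jnp.float32)))  # 7.0
```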
Problem
Implement aot_compiled_sum(x) that:
- Wraps jnp.sum with jax.jit.
- AOT-compiles it via .lower(x).compile().
- Calls the compiled function on x and returns the result.

Input: x, a 1-D JAX array.
Returns: a scalar equal to sum(x).
Hints: jax, jit, aot, lower