hard primitives

AOT via jit.lower().compile()

Why this matters

JAX normally compiles functions just-in-time (JIT): the first call to a jitted function traces it and triggers XLA compilation for the observed argument shapes and dtypes. In production you may want to pay that compilation cost at build time, not at inference time. The AOT (ahead-of-time) pattern lets you do exactly that:

compiled = (
    jax.jit(f).lower(*example_args)   # trace + lower to StableHLO
    .compile()                        # compile to a device executable
)

The returned object is a compiled callable (a jax.stages.Compiled). Calling it at runtime skips tracing and compilation entirely; the call only dispatches the executable that was already built.
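The two stages can also be held separately, which makes it easier to see what each one produces. The following is a minimal sketch, assuming a single default device; the function f is a hypothetical stand-in, not part of the exercise.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function used only to illustrate the stages.
    return jnp.sum(x * 2.0)

x = jnp.arange(4.0)             # [0., 1., 2., 3.]

lowered = jax.jit(f).lower(x)   # stage 1: trace + lower (jax.stages.Lowered)
hlo_text = lowered.as_text()    # textual IR, useful for inspection or diffing
compiled = lowered.compile()    # stage 2: build the executable (jax.stages.Compiled)

result = compiled(x)            # runs the executable; no tracing, no compilation
```

Keeping the lowered object around is optional; it is mainly handy when you want to inspect the IR before paying for compilation.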

This matters for:

  • Latency-sensitive serving: first request is as fast as subsequent ones.
  • Build-time validation: catch shape/dtype mismatches before deployment.
  • Reproducible binaries: pin the compiled artifact to a specific shape.
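Build-time validation does not need real data: lower() also accepts abstract jax.ShapeDtypeStruct specs. The sketch below assumes a hypothetical matvec function and made-up shapes; it only illustrates that a mismatch surfaces while lowering, before anything is deployed.

```python
import jax
import jax.numpy as jnp

def matvec(w, x):
    # Hypothetical model step used only for illustration.
    return jnp.dot(w, x)

# Abstract shape/dtype specs: no concrete arrays are needed to lower or compile.
w_spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)
x_spec = jax.ShapeDtypeStruct((4,), jnp.float32)
bad_x_spec = jax.ShapeDtypeStruct((5,), jnp.float32)   # wrong length on purpose

# Matching specs lower and compile cleanly.
compiled = jax.jit(matvec).lower(w_spec, x_spec).compile()

# Mismatched specs fail during tracing, i.e. at "build time".
try:
    jax.jit(matvec).lower(w_spec, bad_x_spec)
    build_error = None
except (TypeError, ValueError) as err:
    build_error = err   # contracting dimensions 4 and 5 do not match
```

A build script could run checks like this for every entry point and fail the build if any of them raises.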

Worked mini-example

import jax
import jax.numpy as jnp

x = jnp.ones(4)
f = jax.jit(jnp.sum)

# Compile once at "build time"
compiled = f.lower(x).compile()

# Fast path at "inference time"
result = compiled(x)   # → 4.0, no recompilation

Common pitfalls

  • Shape binding: the compiled function is bound to the example args’ shapes and dtypes. Calling it with a different shape or dtype raises an error (a TypeError) instead of recompiling.
  • Shape-polymorphic AOT: use jax.export with jax.ShapeDtypeStruct and symbolic dimensions (jax.export.symbolic_shape) if you need to handle multiple shapes from one exported artifact.
  • lower() vs compile(): lower() alone only traces and lowers to StableHLO (a Lowered object whose as_text() shows the IR); compile() turns it into an executable. Both are needed for AOT.
  • Not re-jit-able: the compiled object does not cache additional shapes the way jax.jit does. Each shape needs its own lower().compile() call.

Problem

Implement aot_compiled_sum(x) that:

  1. Wraps jnp.sum with jax.jit.
  2. AOT-compiles it via .lower(x).compile().
  3. Calls the compiled function on x and returns the result.

Args:

  • x: 1-D JAX array.

Returns: scalar — sum(x).
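
One way the three steps could fit together is sketched below. This is not the site's reference solution, only a minimal reading of the problem statement; note that it compiles on every call, which is what the statement asks for but not what you would do in a serving loop.

```python
import jax
import jax.numpy as jnp

def aot_compiled_sum(x):
    # 1. Wrap jnp.sum with jax.jit.
    jitted = jax.jit(jnp.sum)
    # 2. AOT-compile for the shape and dtype of x.
    compiled = jitted.lower(x).compile()
    # 3. Run the compiled executable on x.
    return compiled(x)

total = aot_compiled_sum(jnp.arange(5.0))   # sum of [0., 1., 2., 3., 4.]
```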

Hints

jax jit aot lower
