medium
primitives
Counting Jaxpr Equations
Why this matters
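jax.make_jaxpr(f)(args) returns a ClosedJaxpr: JAX's intermediate
representation (IR) of f, plus any captured constants. Every primitive operation f applies during tracing becomes one equation
(jaxpr.eqns). Higher-order primitives, such as an inner jax.jit or lax.cond, count as a single equation whose body is nested in its params. Counting these equations gives you a quick complexity
measure of a traced function without running it, because make_jaxpr only traces f with abstract shapes and dtypes.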
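A minimal sketch of inspecting a jaxpr. The function f is a hypothetical example; the printed text form and params are version-dependent, but every equation exposes its primitive, inputs, outputs, and static params.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example function: one unary op and one multiply.
    return jnp.sin(x) * 2.0

# make_jaxpr traces f with abstract shape/dtype information; no sine is computed.
closed = jax.make_jaxpr(f)(jnp.ones(3))
print(closed)  # text form of the IR

for eqn in closed.eqns:
    # Each equation records a primitive, its input/output variables, and static params.
    print(eqn.primitive.name, eqn.invars, eqn.outvars, eqn.params)

names = [eqn.primitive.name for eqn in closed.eqns]
print(len(closed.eqns))
```

ClosedJaxpr forwards .eqns to its inner .jaxpr, so closed.eqns and closed.jaxpr.eqns are the same list.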
Practical uses:
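- Debugging traces: confirm that tracing emits the primitives you expect (e.g., mul, not dot_general, for a broadcast multiply).
- Checking ops before compilation: jaxprs are produced before XLA runs, so they show unfused primitives. Fusing several elementwise ops into one kernel happens later, inside XLA; to check fusion, inspect the compiled HLO instead (for example, jax.jit(f).lower(x).compile().as_text()).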
- Comparing implementations: ensure a refactored function produces the same IR as the original.
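The debugging and comparison uses above can be sketched by reducing a jaxpr to its ordered list of primitive names. The helper primitive_names is hypothetical (not a JAX API), and comparing names alone ignores params and shapes, so treat equality as a quick check rather than proof of identical IR.

```python
import jax
import jax.numpy as jnp

def primitive_names(fn, *args):
    # Hypothetical helper: ordered top-level primitive names in fn's jaxpr.
    return [eqn.primitive.name for eqn in jax.make_jaxpr(fn)(*args).eqns]

x = jnp.arange(3.0)
m = jnp.ones((2, 3))

# Debugging: a broadcast multiply should trace to `mul`, not `dot_general`
# (the primitive behind jnp.matmul / jnp.dot).
bcast = primitive_names(lambda a, b: a * b, x, m)

# Comparing implementations: function-call and method-call sums trace the same way...
orig = primitive_names(lambda v: jnp.sum(v * v), x)
refactor = primitive_names(lambda v: (v * v).sum(), x)
# ...while v ** 2 swaps `mul` for `integer_pow`.
alt = primitive_names(lambda v: jnp.sum(v ** 2), x)

print(bcast)
print(orig == refactor, orig == alt)
```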
Worked mini-example
import jax
import jax.numpy as jnp
x = jnp.ones(3)
jaxpr = jax.make_jaxpr(lambda x: jnp.sum(x))(x)
print(len(jaxpr.eqns)) # 1 — just a reduce_sum
For lambda x: jnp.sum(x ** 2 + jnp.exp(x)):
- integer_pow(x, 2): 1 equation
- exp(x): 1 equation
- add(pow_out, exp_out): 1 equation
- reduce_sum(...): 1 equation

→ 4 equations total (may vary by JAX version).
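To check this breakdown against your installed JAX version, one option is to list the primitive name of each top-level equation. A minimal sketch; the exact list printed depends on your JAX release.

```python
import collections

import jax
import jax.numpy as jnp

closed = jax.make_jaxpr(lambda x: jnp.sum(x ** 2 + jnp.exp(x)))(jnp.ones(3))

# Primitive name of each top-level equation, in trace order.
names = [eqn.primitive.name for eqn in closed.eqns]
print(names)
print(collections.Counter(names))
print(closed)  # full text form, with shape/dtype annotations on every variable
```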
Common pitfalls
- Counting source lines instead of ops: JAX traces primitives, not Python statements. One Python expression can produce multiple equations.
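- Version sensitivity: the equation count can shift between JAX versions as jax.numpy changes which primitives its functions emit (for example, a helper that appears as one nested jit equation in one release may be inlined into its primitives in another).
- Dimension sizes don't matter here: for these elementwise ops and a full reduction, the equation count depends on the input's dtype and rank, not on the specific size of each dimension.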
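The pitfalls above can be checked directly. In this sketch, top_level_count and recursive_count are hypothetical helpers, not JAX APIs; recursive_count assumes nested bodies live in equation params as objects exposing .eqns (true for the body of an inner jax.jit), and it also descends into tuples of them (as used by lax.cond branches).

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2 + jnp.exp(x))

def top_level_count(fn, *args):
    # Hypothetical helper: equations in the outermost jaxpr only.
    return len(jax.make_jaxpr(fn)(*args).eqns)

def recursive_count(jaxpr):
    # Hypothetical helper: counts each equation, then descends into any
    # sub-jaxprs found in its params (single objects or tuples/lists of them).
    total = 0
    for eqn in jaxpr.eqns:
        total += 1
        for p in eqn.params.values():
            subs = p if isinstance(p, (tuple, list)) else (p,)
            for sub in subs:
                if hasattr(sub, "eqns"):
                    total += recursive_count(sub)
    return total

# Dimension sizes: same count for length 3 and length 1000.
small = top_level_count(f, jnp.ones(3))
large = top_level_count(f, jnp.ones(1000))

# Dtype: int32 input forces promotions to float before exp and add.
as_int = top_level_count(f, jnp.ones(3, dtype=jnp.int32))

def g(x):
    # An inner jax.jit shows up as one top-level equation with its body nested inside.
    return jax.jit(lambda y: y * 2.0)(x) + 1.0

closed_g = jax.make_jaxpr(g)(jnp.ones(3))
print(small, large, as_int)
print(len(closed_g.eqns), recursive_count(closed_g))
```

Whether to count nested equations is a design choice: top-level counts match the problem's definition, while recursive counts are more stable when a release changes which helpers are inlined.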
Problem
Implement count_jaxpr_eqns(x) that:
- Calls jax.make_jaxpr(lambda x: jnp.sum(x ** 2 + jnp.exp(x)))(x).
- Returns the number of equations in the Jaxpr as a jnp.float32 scalar.

x: 1-D JAX array.
Returns: scalar (float32) equal to float(len(jaxpr.eqns)).
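One possible sketch that follows the spec above; using jnp.float32(...) to build the float32 scalar is our choice, not necessarily the reference solution's.

```python
import jax
import jax.numpy as jnp

def count_jaxpr_eqns(x):
    # Trace the target function with x's shape and dtype (no numeric work is done).
    closed = jax.make_jaxpr(lambda x: jnp.sum(x ** 2 + jnp.exp(x)))(x)
    # Count top-level equations and return them as a 0-d float32 array.
    return jnp.float32(len(closed.eqns))

result = count_jaxpr_eqns(jnp.arange(5.0))
print(result)
```

With the breakdown in the worked example, the result would be 4.0, though the exact value depends on the JAX version.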
Hints
jax
jaxpr
ir-inspection