medium primitives

Counting Jaxpr Equations

Why this matters

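jax.make_jaxpr(f)(args) returns a ClosedJaxpr: a Jaxpr (JAX’s intermediate representation, or IR) together with any constants it closes over. Every primitive applied to a traced value inside f becomes one equation, listed in jaxpr.eqns. Counting these equations gives a quick, rough measure of a traced function’s size without running any numerical computation: make_jaxpr only executes f’s Python code once, on abstract tracer values.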
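A minimal sketch of inspecting the returned object: it prints the jaxpr and lists each equation's primitive through eqn.primitive.name. The function f here is an illustrative example, not part of the original problem; closed.eqns is shorthand for closed.jaxpr.eqns.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative function: traces to two primitives, sin then add.
    return jnp.sin(x) + x

closed = jax.make_jaxpr(f)(jnp.ones(3))
print(type(closed).__name__)   # the wrapper type around the Jaxpr
print(closed)                  # human-readable IR, one equation per line

# The underlying Jaxpr holds the equations; each records which primitive it applies.
names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(len(closed.eqns), names)
```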

Practical uses:

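  • Debugging traces: confirm that tracing produces the primitives you expect (e.g., mul rather than dot_general for a broadcast multiply).
  • Checking what reaches the compiler: the jaxpr lists primitives before XLA fuses them, so several elementwise ops still appear as separate equations here. Fusion itself is only visible in the compiled HLO (e.g., jax.jit(f).lower(x).compile().as_text()).
  • Comparing implementations: the equation count is a cheap first check that a refactored function traces to the same IR as the original; comparing the printed jaxprs is a stricter one.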
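The first and third uses can be sketched with a small helper, primitive_names (a hypothetical name introduced here), that lists the primitive of every top-level equation. This assumes both implementations are traced with the same input dtypes; comparing names is weaker than comparing the full printed jaxprs, which also include shapes and parameters.

```python
import jax
import jax.numpy as jnp

def primitive_names(f, *args):
    """Hypothetical helper: primitive names of f's top-level equations, in trace order."""
    return [eqn.primitive.name for eqn in jax.make_jaxpr(f)(*args).eqns]

a = jnp.ones((4, 4))
b = jnp.ones((4, 4))

# Debugging traces: an elementwise product vs. a matrix product.
elementwise = primitive_names(lambda a, b: a * b, a, b)
matrix = primitive_names(lambda a, b: a @ b, a, b)
print(elementwise, matrix)

# Comparing implementations: same math, written two ways.
def original(x):
    return jnp.sum(x ** 2)

def refactored(x):
    sq = x ** 2
    return sq.sum()

x = jnp.arange(5.0)
same_ops = primitive_names(original, x) == primitive_names(refactored, x)
print(same_ops)
```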

Worked mini-example

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3)
jaxpr = jax.make_jaxpr(lambda x: jnp.sum(x))(x)
print(len(jaxpr.eqns))   # 1: just a reduce_sum
```

For lambda x: jnp.sum(x ** 2 + jnp.exp(x)), the trace contains:

  • integer_pow(x, 2): 1 equation
  • exp(x): 1 equation
  • add(pow_out, exp_out): 1 equation
  • reduce_sum(...): 1 equation

That is 4 equations in total (the exact count may vary by JAX version).
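To check this breakdown against your installed JAX version, you can print each equation's primitive and output shape. This is a sketch; the primitive names listed above are what it is expected to show, not a guarantee across versions.

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(x ** 2 + jnp.exp(x))
closed = jax.make_jaxpr(f)(jnp.ones(3))

# One line per equation: primitive name, then the shape of its output.
for eqn in closed.eqns:
    print(eqn.primitive.name, eqn.outvars[0].aval.shape)

names = [eqn.primitive.name for eqn in closed.eqns]
print(len(names))
```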

Common pitfalls

  • Counting source lines instead of ops: JAX traces primitives, not Python statements. One Python expression can produce multiple equations.
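  • Version sensitivity: the equation count can shift between JAX versions as jax.numpy changes which primitives a function emits. Nesting matters too: a helper wrapped in jax.jit appears as a single call equation whose body is a nested sub-jaxpr, so len(jaxpr.eqns) counts only the top level.
  • Dimensions don’t matter here, dtype can: the jaxpr is specialized to the input’s shape and dtype, but for elementwise ops followed by a full reduction the equation count does not depend on the specific dimensions. The dtype can change it: an integer input, for example, adds convert_element_type equations so that exp and add operate on floats.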
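The last two pitfalls can be checked directly. The sketch below compares counts across shapes and dtypes, then wraps part of f in jax.jit to show a nested call collapsing into one top-level equation. count_eqns, sub_jaxprs, and count_leaf_eqns are hypothetical helpers; the recursive one finds sub-jaxprs by duck typing (any parameter value with an eqns attribute), which is an assumption about how call-like primitives store their bodies, not a documented API.

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(x ** 2 + jnp.exp(x))

def count_eqns(fn, x):
    """Hypothetical helper: top-level equation count of fn traced at x."""
    return len(jax.make_jaxpr(fn)(x).eqns)

# Dimensions don't change the count for this function...
same_1d = count_eqns(f, jnp.ones(3)) == count_eqns(f, jnp.ones(1000))
count_2d = count_eqns(f, jnp.ones((2, 5)))
# ...but dtype can: an int32 input needs casts to float before exp and add.
count_int = count_eqns(f, jnp.ones(3, dtype=jnp.int32))

@jax.jit
def helper(x):
    return x ** 2 + jnp.exp(x)

g = lambda x: jnp.sum(helper(x))  # same math as f, with one level of nesting

def sub_jaxprs(params):
    """Yield any Jaxpr-like objects stored in an equation's params (duck-typed)."""
    for value in params.values():
        items = value if isinstance(value, (tuple, list)) else (value,)
        for item in items:
            if hasattr(item, "eqns"):  # Jaxpr and ClosedJaxpr both expose .eqns
                yield item

def count_leaf_eqns(jaxpr):
    """Hypothetical helper: count equations, recursing into nested sub-jaxprs.
    For control flow such as cond, the equations of all branches are summed."""
    total = 0
    for eqn in jaxpr.eqns:
        subs = list(sub_jaxprs(eqn.params))
        total += sum(count_leaf_eqns(s) for s in subs) if subs else 1
    return total

x = jnp.ones(3)
top_level_g = count_eqns(g, x)
leaf_g = count_leaf_eqns(jax.make_jaxpr(g)(x))
print(same_1d, count_2d, count_int)
print(top_level_g, leaf_g)
```

Summing over all branches is a design choice: it measures how much IR exists, not how much executes, which is usually what a size check wants.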

Problem

Implement count_jaxpr_eqns(x) that:

  1. Calls jax.make_jaxpr(lambda x: jnp.sum(x ** 2 + jnp.exp(x)))(x).
  2. Returns the number of equations in the Jaxpr as a jnp.float32 scalar.
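
Input: x, a 1-D JAX array.

Returns: a float32 scalar holding the equation count, i.e. jnp.float32(len(jaxpr.eqns)). Note that float(len(jaxpr.eqns)) on its own gives a Python float, not a jnp.float32.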
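One possible solution, as a minimal sketch: trace the fixed function at x and wrap the count in jnp.float32 so the return type matches the spec.

```python
import jax
import jax.numpy as jnp

def count_jaxpr_eqns(x):
    # Trace the fixed function at x's shape and dtype; no numerical work runs.
    closed = jax.make_jaxpr(lambda x: jnp.sum(x ** 2 + jnp.exp(x)))(x)
    # Return the equation count as a float32 scalar array.
    return jnp.float32(len(closed.eqns))

print(count_jaxpr_eqns(jnp.ones(3)))
```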


jax jaxpr ir-inspection
