medium primitives

jaxpr with Multiple Args

Why this matters

jax.make_jaxpr(f)(*args) returns a ClosedJaxpr, a wrapper that pairs the traced computation (.jaxpr) with the values of any constants it closes over (.consts). The inner .jaxpr object exposes the structure that later transformations and the compiler work from: invars (input variables), outvars (output variables), and eqns (primitive equations).
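To make the wrapper and its inner fields concrete, here is a minimal sketch that traces a small function and inspects each part. The function f and its body are illustrative assumptions, not part of the problem.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function, chosen only to produce more than one equation.
def f(x, y):
    return jnp.sin(x) * y

closed = jax.make_jaxpr(f)(jnp.ones(3), jnp.ones(3))
inner = closed.jaxpr  # the Jaxpr wrapped by the ClosedJaxpr

wrapper_name = type(closed).__name__                 # name of the wrapper class
n_invars = len(inner.invars)                         # one per traced array input
n_outvars = len(inner.outvars)                       # one per returned array
eqn_names = [eqn.primitive.name for eqn in inner.eqns]  # primitives, in order

print(wrapper_name, n_invars, n_outvars, eqn_names)
print(closed.consts)  # values bound to inner.constvars; empty when nothing is closed over
```

Printing closed itself gives the familiar human-readable jaxpr text, which is often the quickest way to cross-check these counts.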

Counting invars tells you how many array inputs the traced function has after pytree flattening, one per leaf. For a two-argument function f(x, y) whose arguments are each a single array, this is exactly 2, regardless of array shapes. This kind of introspection is useful when debugging higher-order transformations (vmap, grad, jit) where you want to verify that the tracer sees the arity you intended.
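As a rough illustration of that debugging use, the sketch below traces the same two-argument function through several transformations and counts invars each time. The function f and the dictionary name counts are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued function, so that jax.grad applies to it.
def f(x, y):
    return jnp.sum(x * y)

x = jnp.ones(3)
y = jnp.ones(3)

variants = {
    "plain": f,
    "grad": jax.grad(f),   # differentiates with respect to x only
    "vmap": jax.vmap(f),   # maps over the leading axis of both arguments
    "jit": jax.jit(f),
}

# Each transformed function still takes the two arrays x and y as inputs,
# so each traced Jaxpr has two invars.
counts = {
    name: len(jax.make_jaxpr(g)(x, y).jaxpr.invars)
    for name, g in variants.items()
}
print(counts)
```

The equations inside each jaxpr differ considerably, but the input arity is the same in all four cases.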

Worked mini-example

import jax
import jax.numpy as jnp

x = jnp.ones(3)
y = jnp.ones(3)

closed = jax.make_jaxpr(lambda x, y: x + y)(x, y)
inner = closed.jaxpr            # the inner Jaxpr
print(inner.invars)             # two abstract variables, one per argument (printed form varies by JAX version)
print(len(inner.invars))        # 2

Common pitfalls

  • make_jaxpr returns a ClosedJaxpr, not a Jaxpr: you must access .jaxpr to reach the inner Jaxpr with its invars, outvars, eqns.
  • Pytree flattening happens first: a function accepting a single dict such as {"x": a, "y": b} still produces one invar per leaf, so two here.
  • Don’t confuse invars with constvars: constvars stand for closed-over constants, while invars are the traced function arguments.
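The last two pitfalls can be checked directly. This is a small sketch under assumed names (add_dict, add_bias, and bias are hypothetical). Whether a closed-over array shows up as a constvar can depend on the JAX version, so the code only prints that count rather than relying on it.

```python
import jax
import jax.numpy as jnp

# One dict argument with two leaves: flattening yields two traced inputs.
def add_dict(d):
    return d["x"] + d["y"]

closed_dict = jax.make_jaxpr(add_dict)({"x": jnp.ones(3), "y": jnp.ones(3)})
n_dict_invars = len(closed_dict.jaxpr.invars)

# A closed-over array is not a function argument, so it is not an invar.
bias = jnp.arange(3.0)

def add_bias(x):
    return x + bias

closed_bias = jax.make_jaxpr(add_bias)(jnp.ones(3))
n_bias_invars = len(closed_bias.jaxpr.invars)
n_bias_constvars = len(closed_bias.jaxpr.constvars)  # version-dependent

print(n_dict_invars)     # 2: one invar per dict leaf
print(n_bias_invars)     # 1: only x is a traced argument
print(n_bias_constvars)  # count of hoisted constants, if any
```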

Problem

Implement jaxpr_invar_count(x, y) that:

  1. Calls jax.make_jaxpr(lambda x, y: x + y)(x, y) to trace a two-argument addition.
  2. Accesses the inner Jaxpr via .jaxpr.
  3. Returns the number of invars as a jnp.float32 scalar.

Arguments:

  • x, y: 1-D JAX arrays of the same shape.

Returns: scalar (float32) — always 2.0 for this function.
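
One possible way to follow the three steps above is sketched here. It is an illustrative attempt, not the reference solution, and it assumes that jnp.float32(...) is an acceptable way to build the scalar.

```python
import jax
import jax.numpy as jnp

def jaxpr_invar_count(x, y):
    # Step 1: trace the two-argument addition.
    closed = jax.make_jaxpr(lambda x, y: x + y)(x, y)
    # Step 2: unwrap the ClosedJaxpr to reach the inner Jaxpr.
    inner = closed.jaxpr
    # Step 3: count the input variables and return a float32 scalar.
    return jnp.float32(len(inner.invars))

result = jaxpr_invar_count(jnp.ones(3), jnp.ones(3))
print(result, result.dtype)
```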

Hints

jax make-jaxpr ir
