jaxpr with Multiple Args
Why this matters
jax.make_jaxpr(f)(*args) returns a ClosedJaxpr — the IR wrapper that
captures both the traced computation and its constant literals. Drilling into
the inner .jaxpr object exposes the structural metadata that the compiler
uses: invars (input variables), outvars (output variables), and eqns
(primitive equations).
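As a rough illustration of those three fields, the sketch below traces a small hypothetical function f (not part of the problem) and reads off the counts. The exact primitive names listed depend on how JAX lowers the arithmetic, so treat that list as informational.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function, used only to have something to trace.
def f(x, y):
    return x * y + 1.0

closed = jax.make_jaxpr(f)(jnp.ones(3), jnp.ones(3))
inner = closed.jaxpr  # unwrap the ClosedJaxpr

n_in = len(inner.invars)    # traced inputs
n_out = len(inner.outvars)  # traced outputs
prims = [eqn.primitive.name for eqn in inner.eqns]  # one name per equation

print(n_in, n_out, prims)
```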
Counting invars tells you how many distinct traced arguments the function
accepted after pytree-flattening. For a two-argument function f(x, y) whose
arguments are both arrays, this is exactly 2, regardless of array shapes. This
kind of introspection is useful when debugging higher-order transformations
(vmap, grad, jit) where you want to verify that the compiler sees the arity
you intended.
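One way to run that arity check is sketched below. The helper invar_count and the function f are hypothetical names introduced for illustration. The sketch assumes each transformed function is still called with the same two array arguments, so every count should come out as 2.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued function so that jax.grad applies.
def f(x, y):
    return jnp.sum(x * y)

# Hypothetical helper: trace fn on args and count the top-level invars.
def invar_count(fn, *args):
    return len(jax.make_jaxpr(fn)(*args).jaxpr.invars)

x = jnp.ones(3)
y = jnp.ones(3)

plain_count = invar_count(f, x, y)
grad_count = invar_count(jax.grad(f), x, y)  # differentiates w.r.t. x, still takes (x, y)
jit_count = invar_count(jax.jit(f), x, y)
vmap_count = invar_count(jax.vmap(f), jnp.ones((4, 3)), jnp.ones((4, 3)))

print(plain_count, grad_count, jit_count, vmap_count)
```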
Worked mini-example
import jax
import jax.numpy as jnp
x = jnp.ones(3)
y = jnp.ones(3)
closed = jax.make_jaxpr(lambda x, y: x + y)(x, y)
inner = closed.jaxpr # the inner Jaxpr
print(inner.invars) # two abstract variables; the printed form varies by JAX version
print(len(inner.invars)) # 2
Common pitfalls
- make_jaxpr returns a ClosedJaxpr, not a Jaxpr: you must access .jaxpr to reach the inner Jaxpr with its invars, outvars, eqns.
- Pytree flattening happens first: a function accepting one dict {"x": arr} still produces multiple invars if the dict has multiple leaves.
- Don’t confuse invars with constvars: constvars captures closed-over constants; invars are the traced function arguments.
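The last two pitfalls can be checked with a short sketch. The functions g and h and the array c are hypothetical. How a closed-over array is recorded (as a constvar or inlined) may differ between JAX versions, so the sketch only relies on constvars and consts having matching lengths.

```python
import jax
import jax.numpy as jnp

# One dict argument with two leaves flattens to two invars.
def g(d):
    return d["x"] + d["y"]

closed_g = jax.make_jaxpr(g)({"x": jnp.ones(3), "y": jnp.ones(3)})
dict_invars = len(closed_g.jaxpr.invars)

# A closed-over array is not a function argument, so it is not an invar.
c = jnp.arange(3.0)

def h(x):
    return x + c

closed_h = jax.make_jaxpr(h)(jnp.ones(3))
h_invars = len(closed_h.jaxpr.invars)
h_constvars = len(closed_h.jaxpr.constvars)  # variables bound to captured constants
h_consts = len(closed_h.consts)              # the captured values themselves

print(dict_invars, h_invars, h_constvars, h_consts)
```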
Problem
Implement jaxpr_invar_count(x, y) that:
- Calls jax.make_jaxpr(lambda x, y: x + y)(x, y) to trace a two-argument addition.
- Accesses the inner Jaxpr via .jaxpr.
- Returns the number of invars as a jnp.float32 scalar.

- x, y: 1-D JAX arrays of the same shape.
Returns: scalar (float32) — always 2.0 for this function.
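A minimal sketch of one way to satisfy this specification follows. It simply chains the listed steps and is not taken from the site's reference solution.

```python
import jax
import jax.numpy as jnp

def jaxpr_invar_count(x, y):
    # Trace the two-argument addition; the result is a ClosedJaxpr.
    closed = jax.make_jaxpr(lambda x, y: x + y)(x, y)
    # Unwrap to the inner Jaxpr, which carries invars.
    inner = closed.jaxpr
    # Convert the Python int count to a float32 scalar.
    return jnp.float32(len(inner.invars))

result = jaxpr_invar_count(jnp.ones(3), jnp.ones(3))
print(result)
```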