medium
primitives
Pretty-Printed jaxpr Length
Why this matters
JAX offers two human-readable representations of a traced program:
- str(closed_jaxpr): uses the ClosedJaxpr string format, which includes constant literals inline.
- jaxpr.pretty_print(): the inner Jaxpr's dedicated pretty-printer, which produces the canonical { lambda ; ... let ... in (...) } format used in documentation and academic papers.
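The following sketch prints both representations from a single trace of the sum-of-squares function. It deliberately compares the two strings instead of assuming they match, since whether they are identical can depend on the JAX version.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3)
closed = jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x)  # a ClosedJaxpr

via_str = str(closed)                 # string form of the ClosedJaxpr
via_pp = closed.jaxpr.pretty_print()  # pretty-printer of the inner Jaxpr

print(via_str)
print(via_pp)
# Compare rather than assume: the result may differ across JAX releases.
print("identical:", via_str == via_pp)
```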
Understanding pretty_print() is useful when:
- Comparing implementations side by side: two functions that perform the same operations should produce the same primitive sequence, regardless of how the Python code is structured.
- Documenting traced programs: the pretty_print format is compact and readable enough for inclusion in papers and design docs, although the exact text can change between releases.
- Teaching JAX internals: the lambda-calculus-style output maps directly onto the semantics of JAX's functional IR.
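One way to make the side-by-side comparison concrete is to pull the ordered primitive names out of each traced jaxpr through its eqns list. In this sketch, primitive_names and the three sum_sq_* functions are hypothetical helpers for illustration. The first two perform the same operations with different Python structure; the third swaps x ** 2 for x * x, which changes the primitive sequence.

```python
import jax
import jax.numpy as jnp

def primitive_names(fn, *args):
    """Trace fn and return the ordered primitive names in its jaxpr (hypothetical helper)."""
    closed = jax.make_jaxpr(fn)(*args)
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]

def sum_sq_lambda(x):
    return jnp.sum(x ** 2)

def sum_sq_steps(x):
    squared = x ** 2          # same operations, split across Python statements
    total = jnp.sum(squared)
    return total

def sum_sq_mul(x):
    return jnp.sum(x * x)     # different operation: multiply instead of power

x = jnp.ones(3)
a = primitive_names(sum_sq_lambda, x)
b = primitive_names(sum_sq_steps, x)
c = primitive_names(sum_sq_mul, x)

# Full pretty-printed text of the two equivalent versions.
same_text = (jax.make_jaxpr(sum_sq_lambda)(x).jaxpr.pretty_print()
             == jax.make_jaxpr(sum_sq_steps)(x).jaxpr.pretty_print())

print(a, b, c)
print("a == b:", a == b, "| a == c:", a == c, "| same text:", same_text)
```

Comparing primitive names ignores type annotations and parameters such as y=2; compare the full pretty_print() strings when those details matter too.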
Worked mini-example
import jax
import jax.numpy as jnp
x = jnp.ones(3)
closed = jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x)
pp = closed.jaxpr.pretty_print()
print(pp)
# Output (exact text varies by JAX version):
# { lambda ; a:f32[3]. let
#     b:f32[3] = integer_pow[y=2] a
#     c:f32[] = reduce_sum[axes=(0,)] b
#   in (c,) }
print(len(pp))  # exact value depends on the JAX version; compute it, don't hard-code it
Common pitfalls
- pretty_print() vs str(): pretty_print() is called on the inner Jaxpr (closed.jaxpr.pretty_print()), while str() is called on the ClosedJaxpr directly. The two may differ in whitespace and in how constants are represented.
- Version sensitivity: the exact string, and therefore its length, can change between JAX releases. This problem computes the length at runtime so the answer is always self-consistent; don't hard-code the expected value.
- Shape independence (for fixed rank): for an elementwise op followed by reduce_sum, the Jaxpr structure does not depend on the input shape. A [3] and a [5] input produce the same pretty-printed length because only the digit inside the type annotation changes; a [10] input would lengthen each shaped annotation by one character.
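The shape and version pitfalls can be checked directly by tracing the same function on inputs of different lengths. pp_len_for is a hypothetical helper; the printed lengths depend on the installed JAX version, which is why the sketch prints jax.__version__ alongside them.

```python
import jax
import jax.numpy as jnp

def pp_len_for(n):
    """Pretty-printed jaxpr length for sum-of-squares on a length-n vector (hypothetical helper)."""
    closed = jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(jnp.ones(n))
    return len(closed.jaxpr.pretty_print())

print("JAX", jax.__version__)
len3, len5, len10 = pp_len_for(3), pp_len_for(5), pp_len_for(10)
print(len3, len5, len10)
print("[3] vs [5] equal:", len3 == len5)
print("[10] minus [3]:", len10 - len3)
```

In the output format shown in the mini-example, the [10] case should come out two characters longer than [3] (one extra digit each in the annotations of a and b); the sketch prints the difference rather than asserting it, since other releases may print shapes in more places.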
Problem
Implement pp_jaxpr_len(x) that:
- Calls jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x) to trace the sum-of-squares function on x.
- Accesses the inner Jaxpr via .jaxpr.
- Calls .pretty_print() on it and returns the character length as a jnp.float32 scalar.

Args:
- x: 1-D JAX array.

Returns: scalar (float32), the character length of the pretty-printed Jaxpr.
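A minimal sketch that follows the three steps above. It is one straightforward implementation, not the site's reference solution, and it runs eagerly: the length is a plain Python int that is converted to a float32 array only at the end.

```python
import jax
import jax.numpy as jnp

def pp_jaxpr_len(x):
    """Character length of the pretty-printed jaxpr of sum(x ** 2), as a float32 scalar."""
    closed = jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x)  # step 1: trace on x
    inner = closed.jaxpr                                    # step 2: inner Jaxpr
    text = inner.pretty_print()                             # step 3: pretty-print to str
    return jnp.float32(len(text))                           # 0-d float32 array

result = pp_jaxpr_len(jnp.arange(4.0))
print(result, result.dtype)
```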
Hints
jax
make-jaxpr
pretty-print