medium primitives

Pretty-Printed jaxpr Length

Why this matters

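JAX offers two common ways to turn a traced program into readable text:

  • str(closed_jaxpr): the string form of the ClosedJaxpr returned by jax.make_jaxpr. Scalar constants appear inline as literals; array constants are listed as variables before the ; and their values are stored separately in closed_jaxpr.consts.
  • closed_jaxpr.jaxpr.pretty_print(): the pretty-printer of the inner Jaxpr, which produces the { lambda ; ... let ... in (...) } format used throughout the JAX documentation. In current JAX releases str(closed_jaxpr) is built from this same printer, so the two strings usually match.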
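A quick way to see both forms for the sum-of-squares function is to print them next to each other. This is a sketch: whether the two strings are identical depends on the installed JAX version, so the code reports the comparison instead of assuming it.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3)
closed = jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x)

via_str = str(closed)                  # string form of the ClosedJaxpr
via_pp = closed.jaxpr.pretty_print()   # pretty-printer of the inner Jaxpr

print(via_str)
print(via_pp)
print("identical:", via_str == via_pp)  # report, don't assume, that they match
print("constants:", closed.consts)      # array constants live here, not in the text
```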

Understanding pretty_print() is useful when:

  • Comparing implementations side by side: the printed jaxprs show whether two Python-level variants trace to the same primitive sequence (for example, x ** 2 traces to integer_pow, while x * x traces to mul).
  • Documenting traced programs: the format is compact enough to quote in design docs, although it is not guaranteed to stay the same across JAX releases.
  • Teaching JAX internals: the lambda-calculus style output mirrors the functional, explicitly typed structure of JAX’s intermediate representation.
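The side-by-side comparison can be sketched as follows. The helper name pp_of is hypothetical, introduced only for this illustration; it traces a function on a sample input and returns the pretty-printed inner Jaxpr.

```python
import jax
import jax.numpy as jnp

def pp_of(fn, x):
    """Hypothetical helper: trace fn on x and pretty-print the inner Jaxpr."""
    return jax.make_jaxpr(fn)(x).jaxpr.pretty_print()

x = jnp.ones(3)
pp_pow = pp_of(lambda x: jnp.sum(x ** 2), x)  # expected to use integer_pow
pp_mul = pp_of(lambda x: jnp.sum(x * x), x)   # expected to use mul

print(pp_pow)
print(pp_mul)
# Same math, but the traced primitive sequences differ, so the strings differ.
print("same jaxpr:", pp_pow == pp_mul)
```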

Worked mini-example

import jax
import jax.numpy as jnp

x = jnp.ones(3)
closed = jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x)
pp = closed.jaxpr.pretty_print()
print(pp)
# Typical output (exact text varies across JAX versions):
# { lambda ; a:f32[3]. let
#     b:f32[3] = integer_pow[y=2] a
#     c:f32[] = reduce_sum[axes=(0,)] b
#   in (c,) }
print(len(pp))  # depends on the JAX version; newer releases may print extra parameters

Common pitfalls

  • pretty_print() vs str(): pretty_print() is called on the inner Jaxpr (closed.jaxpr.pretty_print()), whereas str() can be applied to the ClosedJaxpr directly. In current releases they usually produce the same text, but this problem specifies pretty_print(), so use that form rather than relying on the two matching.
  • Version sensitivity: the exact string (and therefore its length) can change between JAX releases. Because the length is computed at runtime, the result always matches the installed JAX version; don’t hard-code an expected value.
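  • Shape independence (for fixed rank): for an elementwise op followed by reduce_sum, the Jaxpr structure does not depend on the array length. A [3] and a [5] input therefore produce the same pretty-printed length, since only the digit inside each f32[...] annotation changes; a [10] input adds one character to each annotation that carries the dimension.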
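The shape-independence point can be checked directly by comparing lengths for a few input sizes. A minimal sketch; pp_len is a hypothetical helper name used only here.

```python
import jax
import jax.numpy as jnp

def pp_len(x):
    """Hypothetical helper: length of the pretty-printed sum-of-squares Jaxpr."""
    return len(jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x).jaxpr.pretty_print())

len3 = pp_len(jnp.ones(3))
len5 = pp_len(jnp.ones(5))
len10 = pp_len(jnp.ones(10))

print(len3, len5, len10)
print("[3] vs [5] equal:", len3 == len5)  # both dimensions are one digit
print("[10] longer by:", len10 - len3)    # extra digit in each f32[...] annotation
```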

Problem

Implement pp_jaxpr_len(x) that:

  1. Calls jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x) to trace the sum-of-squares function on x.
  2. Accesses the inner Jaxpr via .jaxpr.
  3. Calls .pretty_print() on it and returns the character length as a jnp.float32 scalar.

Input: x, a 1-D JAX array.

Returns: a float32 scalar holding the character length of the pretty-printed Jaxpr.
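One possible sketch that follows the three steps above. It converts the length with jnp.asarray and an explicit float32 dtype; jnp.float32(len(pp)) would be an equivalent choice.

```python
import jax
import jax.numpy as jnp

def pp_jaxpr_len(x):
    # 1. Trace the sum-of-squares function on x.
    closed = jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x)
    # 2-3. Pretty-print the inner Jaxpr and measure the string.
    pp = closed.jaxpr.pretty_print()
    # Return the character count as a float32 scalar.
    return jnp.asarray(len(pp), dtype=jnp.float32)

result = pp_jaxpr_len(jnp.arange(4.0))
print(result, result.dtype, result.shape)
```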

Hints

jax make-jaxpr pretty-print
