
make_jaxpr Inspection

Why this matters

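Before XLA compiles anything, JAX traces your Python function into a JAXPR, JAX's intermediate representation (IR). A JAXPR is a sequentially ordered list of primitive equations that maps inputs to outputs. No Python control flow is left behind: Python loops and branches run during tracing, and only structured primitives from jax.lax (such as cond or while) survive as control flow in the IR.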
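To see that Python control flow does not survive tracing, here is a small sketch (the function name double_three_times is only an illustration) that traces a plain Python for loop. The loop runs at trace time, so the resulting JAXPR should hold three separate mul equations and no loop primitive.

```python
import jax
import jax.numpy as jnp

def double_three_times(x):
    # Illustrative function: a plain Python loop that executes during tracing,
    # so each iteration leaves one multiply equation in the IR.
    for _ in range(3):
        x = x * 2.0
    return x

closed = jax.make_jaxpr(double_three_times)(jnp.ones(3))
print(closed)

# closed.jaxpr.eqns lists the equations; each one records its primitive.
prim_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prim_names)
```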

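jax.make_jaxpr(f)(*args) runs only the tracing phase: f is called with abstract tracers that carry the shapes and dtypes of args, nothing is computed on device, and the result is a ClosedJaxpr object. Converting it with str() gives a human-readable listing of every primitive equation JAX will hand on for lowering to XLA. For jnp.sum(x ** 2) traced on a float32 input of shape [3], it looks like this: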

{ lambda ; a:f32[3]. let
    b:f32[3] = integer_pow[y=2] a
    c:f32[] = reduce_sum[axes=(0,)] b
  in (c,) }
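The printed text is only a view of a structured object. A short sketch, assuming a recent JAX release, walks the ClosedJaxpr returned by make_jaxpr: its .jaxpr attribute holds the input variables, equations, and output variables, and each equation records its primitive and static parameters.

```python
import jax
import jax.numpy as jnp

closed = jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(jnp.ones(3, dtype=jnp.float32))
jaxpr = closed.jaxpr  # underlying Jaxpr; the ClosedJaxpr also carries .consts

# Inputs and outputs carry abstract values (shape + dtype), never concrete data.
print([v.aval for v in jaxpr.invars])
print([v.aval for v in jaxpr.outvars])

# One line per equation: primitive name plus its static parameters.
for eqn in jaxpr.eqns:
    print(eqn.primitive.name, dict(eqn.params))

prim_names = [eqn.primitive.name for eqn in jaxpr.eqns]
```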

This is indispensable for:

  • Debugging: verify that your function traces the way you expect.
  • Performance tuning: spot unexpected ops (e.g. gather instead of slice) that might be slow on hardware.
  • Checking static vs. dynamic values: anything that appears as a concrete literal was evaluated at trace time (a potential bug if it should vary).
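The static-versus-dynamic check can be sketched by tracing the same scaling in two ways (the name scale is illustrative). When the factor is captured from a closure, it is evaluated at trace time and baked into the JAXPR as a literal; when it is passed as an argument, it becomes an input variable that can vary between calls.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3)
scale = 3.0  # illustrative constant

# Closed over: evaluated at trace time and embedded as a literal.
baked = jax.make_jaxpr(lambda x: x * scale)(x)

# Passed in: becomes a traced scalar input variable.
traced = jax.make_jaxpr(lambda x, s: x * s)(x, scale)

print(baked)
print(traced)
```

Comparing the two listings shows the literal 3.0 inside the mul equation of the first, and an extra scalar input variable in the second.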

Worked mini-example

import jax
import jax.numpy as jnp

x = jnp.ones(3)
jaxpr = jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x)
print(str(jaxpr))        # human-readable IR
print(len(str(jaxpr)))   # e.g. 126; the exact count varies across JAX versions

Common pitfalls

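  • Shape-dependent IR: the JAXPR (and therefore its string length) depends on the input's shape and dtype, not its values. The shape is printed in the type annotations (e.g. a:f32[3]), so the length changes only when the printed shape does: [3] and [4] inputs give strings of the same length, while a [10] input gives a longer one.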
  • Version-dependent output: the exact string representation can change between JAX versions. This problem computes the length on the fly, so it is self-consistent within a runtime.
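  • make_jaxpr only traces: calling make_jaxpr(f)(x) runs the Python function immediately with abstract tracers and returns the IR. Unlike jax.jit, it never compiles or executes the computation, so treat it as an inspection tool rather than a compiled, cached replacement for f.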
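The shape and value pitfalls can be checked directly. This sketch uses a hypothetical helper, jaxpr_len (not part of JAX), to compare string lengths across inputs. Values should not matter, and only the digit count of the printed shape should: for sum-of-squares, the shape appears twice in the listing (as a:f32[n] and b:f32[n]), so a [10] input should add exactly two characters relative to [3].

```python
import jax
import jax.numpy as jnp

def jaxpr_len(x):
    # Hypothetical helper: length of the printed JAXPR for sum-of-squares on x.
    return len(str(jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x)))

ones3 = jaxpr_len(jnp.ones(3))
range3 = jaxpr_len(jnp.arange(3.0))  # same shape and dtype, different values
ones4 = jaxpr_len(jnp.ones(4))       # one-digit shape, like [3]
ones10 = jaxpr_len(jnp.ones(10))     # two-digit shape

print(ones3, range3, ones4, ones10)
```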

Problem

Implement jaxpr_string_length(x), where x is a 1-D JAX array, that:

  1. Calls jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x) to obtain the JAXPR for the sum-of-squares function traced on x.
  2. Converts the JAXPR to its string representation with str(jaxpr).
  3. Returns the length of that string as a jnp.float32 scalar.

Returns: a float32 scalar equal to float(len(str(jaxpr))).

Example (not from the test set):

  • jaxpr_string_length(jnp.array([1.0, 2.0])) returns the length of the JAXPR string traced for a float32 input of shape [2], as a float32 scalar.
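The three steps of the specification map almost line for line onto code. A minimal sketch, assuming the default 32-bit JAX configuration, that follows them as written:

```python
import jax
import jax.numpy as jnp

def jaxpr_string_length(x):
    # 1. Trace sum-of-squares on x; only x's shape and dtype are used.
    jaxpr = jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x)
    # 2. Render the ClosedJaxpr as human-readable text.
    text = str(jaxpr)
    # 3. Return the character count as a float32 scalar array.
    return jnp.asarray(float(len(text)), dtype=jnp.float32)

result = jaxpr_string_length(jnp.array([1.0, 2.0]))
print(result, result.dtype)
```

Because the length comes from the running JAX version's own printer, the result is self-consistent within one runtime but is not a fixed constant across versions.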
