make_jaxpr Inspection
Why this matters
Before XLA compiles anything, JAX traces your Python function into a JAXPR, JAX's intermediate representation (IR). A JAXPR is an ordered sequence of primitive equations that maps inputs to outputs. No Python control flow survives tracing: loops are unrolled and `if` branches on static values are resolved.
`jax.make_jaxpr(f)(*args)` runs only the tracing phase, without executing the computation, and returns the JAXPR (a `ClosedJaxpr` object). Printing it, or calling `str()` on it, gives a human-readable listing of every primitive JAX will hand to XLA. For `jnp.sum(x ** 2)` traced on a `float32[3]` input:
```
{ lambda ; a:f32[3]. let
    b:f32[3] = integer_pow[y=2] a
    c:f32[] = reduce_sum[axes=(0,)] b
  in (c,) }
```
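Beyond the printed text, the returned object can be inspected directly. The sketch below assumes a recent JAX release, where `make_jaxpr` returns a `ClosedJaxpr` whose `.jaxpr.eqns` list holds the equations; the function `add_then_sum_sq` is a hypothetical example. It also illustrates the control-flow point: the Python `for` loop itself does not appear in the IR, and you should instead see three unrolled `add` equations ahead of `integer_pow` and `reduce_sum`.

```python
import jax
import jax.numpy as jnp

def add_then_sum_sq(x):
    # An ordinary Python loop: it executes during tracing, so the JAXPR
    # records one `add` per iteration rather than any loop construct.
    for _ in range(3):
        x = x + 1.0
    return jnp.sum(x ** 2)

closed = jax.make_jaxpr(add_then_sum_sq)(jnp.ones(3))

# make_jaxpr returns a ClosedJaxpr; .jaxpr holds the equations, and
# in_avals / out_avals describe the input and output shapes and dtypes.
names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(names)
print(closed.in_avals, closed.out_avals)
```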
This is indispensable for:
- Debugging: verify that your function traces the way you expect.
- Performance tuning: spot unexpected ops (e.g. `gather` instead of `slice`) that might be slow on hardware.
- Checking static vs. dynamic values: anything that appears as a concrete literal was evaluated at trace time (a potential bug if it should vary).
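The static-vs-dynamic check can be seen by comparing two hypothetical variants of a scaling function: one closes over a Python float, the other takes it as an argument. In the first JAXPR the value should appear as a literal inside the `mul` equation and the function has a single input variable; in the second it is a traced input. This is a minimal sketch; the exact printed form of the literal varies between JAX versions.

```python
import jax
import jax.numpy as jnp

scale = 3.0  # a Python float captured by closure

def scaled_closure(x):
    # `scale` is read at trace time and baked into the IR as a literal.
    return x * scale

def scaled_arg(x, s):
    # `s` is a function argument, so it becomes a traced input variable.
    return x * s

baked = jax.make_jaxpr(scaled_closure)(jnp.ones(3))
traced = jax.make_jaxpr(scaled_arg)(jnp.ones(3), 3.0)
print(baked)
print(traced)
```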
Worked mini-example
```python
import jax
import jax.numpy as jnp

x = jnp.ones(3)
jaxpr = jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x)
print(str(jaxpr))       # human-readable IR
print(len(str(jaxpr)))  # e.g. 126; the exact value depends on the JAX version
```
Common pitfalls
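- Shape-dependent IR: the JAXPR (and therefore its string length) depends on the input's shape and dtype, not its values. A `[3]` input and a `[4]` input give strings of the same length (only the digit in each `f32[...]` annotation changes), while a `[10]` input adds one character to every such annotation.
- Version-dependent output: the exact string representation can change between JAX versions. This problem computes the length at run time, so it is self-consistent within a given installation.
- `make_jaxpr` traces, it does not compile: unlike `jax.jit`, which returns a function that is compiled on first call and reused afterwards, `make_jaxpr(f)(x)` runs `f` in Python on abstract tracers and hands back the IR. Python side effects in `f` (such as `print`) run during that tracing.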
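A quick way to confirm the shape-dependence pitfall is to trace the same function on inputs that differ only in values, and then in shape. This is a minimal sketch with a hypothetical helper `jaxpr_text`; it makes no assumption about the full text, only about the shape annotations it contains.

```python
import jax
import jax.numpy as jnp

def sum_sq(x):
    return jnp.sum(x ** 2)

def jaxpr_text(x):
    # Hypothetical helper: the printed JAXPR of sum_sq traced on x.
    return str(jax.make_jaxpr(sum_sq)(x))

s3 = jaxpr_text(jnp.ones(3))
s3_zeros = jaxpr_text(jnp.zeros(3))  # same shape and dtype, different values
s10 = jaxpr_text(jnp.ones(10))       # different shape

print(s3 == s3_zeros)     # True: values never reach the IR
print(len(s3), len(s10))  # differ because f32[3] became f32[10]
```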
Problem
Implement `jaxpr_string_length(x)` that:
- Calls `jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x)` to obtain the JAXPR for the sum-of-squares function traced on `x`.
- Converts the JAXPR to its string representation with `str(jaxpr)`.
- Returns the length of that string as a `jnp.float32` scalar.

Arguments:
- `x`: 1-D JAX array.

Returns: a float32 scalar equal to `float(len(str(jaxpr)))`.
Example (not from the test set):
- `jaxpr_string_length(jnp.array([1.0, 2.0]))` returns the length of the JAXPR string for a `float32[2]` input, as a float.
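Following the three steps above literally gives a short sketch like this one (one possible implementation, not the site's reference solution):

```python
import jax
import jax.numpy as jnp

def jaxpr_string_length(x):
    # Trace sum-of-squares on x's shape and dtype; x's values are not used.
    jaxpr = jax.make_jaxpr(lambda x: jnp.sum(x ** 2))(x)
    # Length of the printed IR, returned as a float32 scalar array.
    return jnp.asarray(float(len(str(jaxpr))), dtype=jnp.float32)

print(jaxpr_string_length(jnp.array([1.0, 2.0])))
```

Because the result depends only on shape and dtype, calling it on `jnp.array([5.0, -1.0])` should return the same value as for the example input.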