Jit with static_argnames
Why this matters
When you wrap a function with jax.jit, JAX traces it once using abstract
“tracer” objects instead of real array values. Tracers carry shape and dtype
information, but they carry no numeric values. This is what makes jit
fast: the function is traced and compiled once, and the cached XLA program is
then reused for every later call whose inputs have the same shapes and dtypes.
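A small way to see the "trace once, reuse" behavior is a Python-side counter inside a jitted function. Python side effects run only while JAX traces, so the counter counts traces rather than calls. The function name `double` and the counter are illustrative, not part of any API.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # plain Python counter; only incremented while JAX traces


@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return x * 2.0


double(jnp.ones(3))   # first call: trace with float32[3] tracers, then compile
double(jnp.zeros(3))  # same shape and dtype: reuses the compiled program
double(jnp.ones(4))   # new shape: JAX traces and compiles again
print(trace_count)    # → 2
```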
The problem arises when an argument’s value, not just its shape, must
be known at trace time to determine the output shape. Consider tiling an
array x a total of n times: the output has shape (x.shape[0] * n,).
Under plain jax.jit, n becomes a tracer, and jnp.tile(x, n) has to
compute the output shape from a tracer value it cannot read. JAX raises an
error at trace time because that shape would only be known at runtime.
Depending on the JAX version and code path, this surfaces as a
ConcretizationTypeError or a closely related tracer or shape error; all of
them point at the same root cause.
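A minimal sketch of the failure, assuming a hypothetical helper `tile_n`. Called eagerly, n is a real Python int and everything works. Under plain jax.jit, n is traced and tiling fails at trace time. The exception is caught broadly because its exact class varies across JAX versions.

```python
import jax
import jax.numpy as jnp


def tile_n(x, n):  # hypothetical helper, used only for this illustration
    return jnp.tile(x, n)


x = jnp.array([1.0, 2.0])
eager = tile_n(x, 3)  # no jit: n is a concrete Python int, works
print(eager)

try:
    jax.jit(tile_n)(x, 3)  # n becomes a tracer, so the output shape is unknown
    failed = False
except Exception as err:  # exact exception class depends on the JAX version
    failed = True
    print(type(err).__name__)
```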
The fix is static_argnames. Declaring static_argnames=["n"] tells JAX:
treat n as a Python constant at trace time. JAX bakes the concrete value
of n into the compiled program, producing a specialized XLA binary for that
specific value. If you call the jitted function with a different n, JAX
re-traces and compiles a new binary.
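The per-value caching can be observed with the same counter trick: repeating a static value hits the cache, while a new value forces a fresh trace and compilation. `tile_n` and `tile_static` are illustrative names.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts traces, i.e. compilations


def tile_n(x, n):
    global trace_count
    trace_count += 1       # runs only while JAX traces
    return jnp.tile(x, n)  # n is a plain Python int here


tile_static = jax.jit(tile_n, static_argnames=["n"])
x = jnp.array([1.0, 2.0])

a = tile_static(x, 2)  # n=2: trace + compile
b = tile_static(x, 2)  # n=2 again: cache hit, no trace
c = tile_static(x, 4)  # n=4: new value, so JAX re-traces and compiles
print(trace_count)     # → 2
print(c.shape)         # → (8,)
```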
The trade-off: every distinct value of n causes a new compilation. Use
static_argnames only for arguments that are (a) needed at trace time to
determine shapes or control flow, and (b) drawn from a small set of values.
A tiling helper built on jnp.tile fits well: you might call it with n=2,
n=4, and n=8, which costs only a handful of cache entries.
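Criterion (a) also covers Python control flow. In this sketch, a hypothetical `normalize` function branches on a string argument with an ordinary `if`, which needs a concrete value at trace time; marking `method` static makes that legal, and because it takes only a couple of values, the cache stays small.

```python
import jax
import jax.numpy as jnp


def normalize(x, method):  # hypothetical example function
    # Python-level branch: method must be a concrete value at trace time
    if method == "l2":
        return x / jnp.linalg.norm(x)
    if method == "max":
        return x / jnp.max(jnp.abs(x))
    raise ValueError(f"unknown method: {method!r}")


normalize_jit = jax.jit(normalize, static_argnames=["method"])
x = jnp.array([3.0, 4.0])
print(normalize_jit(x, "l2"))   # divides by the L2 norm, 5.0
print(normalize_jit(x, "max"))  # divides by the largest magnitude, 4.0
```

Strings are hashable, so they work as static cache keys; each distinct `method` value compiles once.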
Worked mini-example
import jax
import jax.numpy as jnp

# BREAKS: n is a tracer under plain jit, so the output shape is unknown at trace time
def repeat_v1(x, n):
    out = jnp.zeros((n * x.shape[0],))  # n is a tracer: shape error
    for i in range(n):                  # range(tracer) also errors
        out = out.at[i * x.shape[0]:(i + 1) * x.shape[0]].set(x)
    return out

try:
    jax.jit(repeat_v1)(jnp.array([1.0, 2.0]), 3)
except Exception as err:
    # The exception class and message vary with the JAX version (a
    # ConcretizationTypeError, a tracer-conversion error, or a TypeError about
    # non-concrete shapes); all mean n had no concrete value at trace time.
    print(type(err).__name__)

# Works: n is marked static, so JAX treats it as a Python int at trace time
def repeat_v2(x, n):
    return jnp.tile(x, n)

jit_repeat = jax.jit(repeat_v2, static_argnames=["n"])
jit_repeat(jnp.array([1.0, 2.0]), 3)  # → Array([1., 2., 1., 2., 1., 2.], dtype=float32)
jit_repeat(jnp.array([1.0, 2.0]), 5)  # re-traces and compiles for n=5; fine
repeat_v1 fails because n is abstract — its value is not available when
JAX tries to create a zeros array of size n * x.shape[0]. repeat_v2
works because jax.jit(..., static_argnames=["n"]) makes n concrete: JAX
compiles a specialized binary for each distinct integer n.
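The failure in repeat_v1 comes from n being traced, not from the loop itself. As a sketch, the same loop-based body works once n is static; here the decorator form `functools.partial(jax.jit, static_argnames=[...])` is used, a common idiom for marking static arguments at definition time. `repeat_loop` is an illustrative name.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=["n"])
def repeat_loop(x, n):
    size = x.shape[0]
    out = jnp.zeros((n * size,), dtype=x.dtype)  # n is a Python int: shape is known
    for i in range(n):                            # plain Python loop, unrolled at trace time
        out = out.at[i * size:(i + 1) * size].set(x)
    return out


print(repeat_loop(jnp.array([1.0, 2.0]), 3))
```

Because the loop is unrolled during tracing, large n yields a large compiled program; jnp.tile is the more compact choice.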
Common pitfalls
- Forgetting static_argnames for shape-controlling args: the function works when called without jit, but the first jitted call fails at trace time. The error typically mentions "Abstract tracer value encountered where concrete value is expected" (a ConcretizationTypeError), though some code paths raise a related tracer or shape error instead.
- Marking too many args static: every distinct combination of static values triggers a separate compilation. If a static arg can take thousands of values, you pay for thousands of compilations and the cache keeps growing.
- static_argnums=... is the positional-index equivalent of static_argnames. Both work; names are more refactor-safe.
- Static args must be hashable: JAX uses the static arg values as cache keys. Python ints, strings, and tuples are fine. Lists and JAX arrays are not hashable, and JAX rejects them with an error about non-hashable static arguments (the underlying cause is Python's TypeError: unhashable type).
- Passing n as a 0-D JAX array: jnp.tile(x, jnp.array(3)) looks harmless without jit, but under jit the 0-D array is a tracer, and jnp.tile cannot use a tracer as the repetition count. Convert it to a Python int with int(n) before the jitted call and pass that as a static argument; calling int(n) inside the jitted function fails once n is already a tracer.
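A sketch exercising three of the pitfalls above with a hypothetical `tile_n` helper: static_argnums and static_argnames give the same result, an unhashable static value is rejected, and a 0-D array is converted to a Python int outside jit. The exception is caught as `(TypeError, ValueError)` because the wrapping error class has differed across JAX versions.

```python
import jax
import jax.numpy as jnp


def tile_n(x, n):  # hypothetical helper
    return jnp.tile(x, n)


by_name = jax.jit(tile_n, static_argnames=["n"])
by_index = jax.jit(tile_n, static_argnums=(1,))  # n is positional argument 1

x = jnp.array([1.0, 2.0])
print(bool(jnp.array_equal(by_name(x, 3), by_index(x, 3))))  # → True

try:
    by_name(x, [3])  # a list is unhashable, so it cannot serve as a cache key
    rejected = False
except (TypeError, ValueError) as err:
    rejected = True
    print(type(err).__name__)

n_arr = jnp.array(3)           # 0-D JAX array, e.g. produced by earlier computation
safe = by_name(x, int(n_arr))  # convert outside jit, then pass a Python int
```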
Problem
Implement jit_with_static_n(x, n) that takes a 1-D JAX array x and an
integer n (delivered as a float; cast it inside), and returns a 1-D array
formed by tiling x exactly n times along axis 0.
The output has shape (x.shape[0] * n,).
Two illustrative examples (not from the test set):
- x = jnp.array([7.0, 8.0]), n = 2 → [7., 8., 7., 8.]
- x = jnp.array([0.0, 1.0, 2.0]), n = 3 → [0., 1., 2., 0., 1., 2., 0., 1., 2.]
JIT note: the function itself does not call jax.jit — the lesson is
that when you wrap it yourself, you would write
jax.jit(jit_with_static_n, static_argnames=["n"]) so that n is treated
as a Python int at trace time. Without static_argnames, the jitted version
would raise a ConcretizationTypeError.
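One possible sketch, not the reference solution: cast the float to an int, tile, and wrap with static_argnames at the call site. Because n is static, it arrives at trace time as a Python float such as 2.0, so int(n) is an ordinary Python conversion; a static float is hashable and works as a cache key.

```python
import jax
import jax.numpy as jnp


def jit_with_static_n(x, n):
    reps = int(n)             # n may arrive as a float such as 3.0
    return jnp.tile(x, reps)  # output shape: (x.shape[0] * reps,)


jitted = jax.jit(jit_with_static_n, static_argnames=["n"])
print(jitted(jnp.array([7.0, 8.0]), 2.0))
print(jitted(jnp.array([0.0, 1.0, 2.0]), 3.0))
```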