medium end_to_end

Jit with static_argnames

Why this matters

When you wrap a function with jax.jit, JAX traces it using abstract “tracer” objects instead of real array values, once for each distinct combination of input shapes and dtypes. Tracers carry shape and dtype information but no numeric values. This is what makes jit fast: the compiled XLA program is reused for any input that has the same shape and dtype.
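The tracing behaviour described above can be observed with a small sketch. The function double and the list trace_log are hypothetical names used only for illustration; the sketch relies on the fact that Python side effects inside a jitted function run only while JAX is tracing.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records the shape seen on each trace

def double(x):
    # This append is a Python side effect, so it runs only during tracing,
    # not when the cached compiled program is executed.
    trace_log.append(x.shape)
    return x * 2

jit_double = jax.jit(double)

jit_double(jnp.ones(3))    # first call: traces and compiles for shape (3,)
jit_double(jnp.zeros(3))   # same shape and dtype: reuses the compiled program
jit_double(jnp.ones(4))    # new shape: traces again

print(trace_log)
```

Under these assumptions the log should hold one entry per distinct input shape, not one per call.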

The problem arises when an argument’s value, not just its shape, must be known at trace time to determine the output shape. Consider tiling a 1-D array x a total of n times: the output has shape (x.shape[0] * n,). Under plain jax.jit, n becomes a tracer, and jnp.tile(x, n) would have to derive the output shape from a value it cannot read. JAX rejects this at trace time with a tracer error (a ConcretizationTypeError or a related TypeError, depending on the operation and the JAX version), because the shape would only be known at runtime.

The fix is static_argnames. Declaring static_argnames=["n"] tells JAX: treat n as a Python constant at trace time. JAX bakes the concrete value of n into the compiled program, producing a specialized XLA binary for that specific value. If you call the jitted function with a different n, JAX re-traces and compiles a new binary.

The trade-off: every distinct value of n causes a new compilation. Use static_argnames only for arguments that are (a) needed at trace time to determine shapes or control flow, and (b) drawn from a small set of values. A function like jnp.tile is a good fit: you might call it with n=2, n=4, n=8, which gives a handful of cache entries and nothing more.
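To make the per-value compilation cost concrete, here is a minimal sketch that records which values of n trigger a trace. The names tile_n and traced_values are assumptions made for this illustration, not part of the lesson's code.

```python
import jax
import jax.numpy as jnp

traced_values = []  # hypothetical helper: one entry per trace

def tile_n(x, n):
    # n is a plain Python int here because it is declared static below.
    traced_values.append(n)
    return jnp.tile(x, n)

jit_tile = jax.jit(tile_n, static_argnames=["n"])
x = jnp.array([1.0, 2.0])

jit_tile(x, n=2)   # traces and compiles a program specialized for n=2
jit_tile(x, n=2)   # same static value: cache hit, no new trace
jit_tile(x, n=4)   # new static value: traces and compiles again

print(traced_values)
```

Each new static value adds one entry to the log, so a static argument that takes many distinct values means many compilations.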

Worked mini-example

import jax
import jax.numpy as jnp

# BREAKS: n is a tracer under plain jit, so the output shape is unknown at trace time
def repeat_v1(x, n):
    out = jnp.zeros((n * x.shape[0],))   # n is a tracer; shape error raised here
    for i in range(n):                    # range(tracer) would also error
        out = out.at[i * x.shape[0]:(i + 1) * x.shape[0]].set(x)
    return out

try:
    jax.jit(repeat_v1)(jnp.array([1.0, 2.0]), 3)
except TypeError as err:
    # JAX tracer errors such as ConcretizationTypeError subclass TypeError;
    # the exact class and message depend on the JAX version.
    print(type(err).__name__)

# Works — `n` marked static; JAX treats it as a Python int at trace time
def repeat_v2(x, n):
    return jnp.tile(x, n)

jit_repeat = jax.jit(repeat_v2, static_argnames=["n"])
jit_repeat(jnp.array([1.0, 2.0]), 3)   # → array([1., 2., 1., 2., 1., 2.])
jit_repeat(jnp.array([1.0, 2.0]), 5)   # → re-traces for n=5; fine

repeat_v1 fails because n is abstract — its value is not available when JAX tries to create a zeros array of size n * x.shape[0]. repeat_v2 works because jax.jit(..., static_argnames=["n"]) makes n concrete: JAX compiles a specialized binary for each distinct integer n.

Common pitfalls

  • Forgetting static_argnames for shape-controlling args: the function works fine without jit, but the first jitted call fails at trace time. Depending on the operation and the JAX version, the error is a ConcretizationTypeError (“Abstract tracer value encountered where concrete value is expected”) or a related TypeError about non-concrete shapes.
  • Marking too many args static: every distinct combination of static values triggers a separate compilation. If a static arg can take thousands of values, your cache grows unboundedly.
  • static_argnums=... is the positional-index equivalent of static_argnames. Both work; names are more refactor-safe.
  • Static args must be hashable: JAX uses the static arg values as cache keys. Python ints, strings, and tuples are fine. Lists and JAX arrays are not hashable, so passing them as static arguments raises an error when the jitted function is called.
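  • Passing n as a 0-D JAX array: jnp.tile(x, jnp.array(3)) looks harmless without jit, but under jit the 0-D array is a tracer, and jnp.tile cannot use a tracer as the repetition count. Convert it to a Python int with int(n) before it reaches the jitted function, or mark n static and cast inside; calling int() on a tracer fails as well.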
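The hashability pitfall can be illustrated with a short sketch. The function reshape_to is a hypothetical example, and the exception class raised for an unhashable static argument is assumed to be either ValueError or TypeError, since it may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

def reshape_to(x, shape):
    # shape is static, so it reaches this function as an ordinary Python tuple.
    return jnp.reshape(x, shape)

jit_reshape = jax.jit(reshape_to, static_argnames=["shape"])
x = jnp.arange(6)

ok = jit_reshape(x, shape=(2, 3))   # a tuple is hashable and works as a cache key

try:
    jit_reshape(x, shape=[2, 3])    # a list is not hashable
    list_rejected = False
except (TypeError, ValueError):     # assumption: the class varies by JAX version
    list_rejected = True

print(ok.shape, list_rejected)
```

Converting a list to a tuple at the call site, for example tuple(shape), is usually enough to avoid the error.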

Problem

Implement jit_with_static_n(x, n) that takes a 1-D JAX array x and a repetition count n (a whole number delivered as a float, such as 3.0; cast it with int(n) inside the function), and returns a 1-D array formed by tiling x exactly n times along axis 0.

The output has shape (x.shape[0] * n,).

Two illustrative examples (not from the test set):

  • x = jnp.array([7.0, 8.0]), n = 2 → [7., 8., 7., 8.]
  • x = jnp.array([0.0, 1.0, 2.0]), n = 3 → [0., 1., 2., 0., 1., 2., 0., 1., 2.]

JIT note: the function itself does not call jax.jit. The lesson is that when you wrap it yourself, you would write jax.jit(jit_with_static_n, static_argnames=["n"]) so that n is treated as a concrete Python value at trace time. Without static_argnames, the jitted version would fail at trace time with a tracer error such as ConcretizationTypeError.
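As a rough sketch of the wrapping described in the JIT note, the snippet below uses one possible body for jit_with_static_n. It is an assumption made for illustration, not the reference solution, and it relies on tracer errors being subclasses of TypeError.

```python
import jax
import jax.numpy as jnp

def jit_with_static_n(x, n):
    # Assumed body for illustration. int(n) works only when n is a concrete
    # Python number, which is the case without jit or when n is static.
    return jnp.tile(x, int(n))

x = jnp.array([7.0, 8.0])

# n marked static: the float 2.0 is hashable and concrete at trace time.
jitted = jax.jit(jit_with_static_n, static_argnames=["n"])
out = jitted(x, n=2.0)

# n left dynamic: int() is applied to a tracer and tracing fails.
try:
    jax.jit(jit_with_static_n)(x, 2.0)
    failed_without_static = False
except TypeError:  # JAX tracer errors subclass TypeError
    failed_without_static = True

print(out, failed_without_static)
```

The same function therefore works unjitted and jitted, provided the caller declares n static.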

Hints

jax jit static-argnames
