
jit with static_argnums

Why this matters

jax.jit traces a function by replacing every argument with an abstract tracer: a symbolic stand-in that knows its shape and dtype but not its value. This is usually what you want. But some arguments must be concrete Python values: anything that controls output shape (like the number of repeats in jnp.tile), a Python int used in a range(...) call, or an integer index into a list.
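To make the "must be concrete" case tangible, here is a small sketch. It assumes a hypothetical helper, sum_first_n_powers, whose integer argument drives a Python range(...) loop. Under plain jax.jit that argument becomes a tracer and the loop cannot be built; marking it static turns it back into an ordinary int.

```python
import jax
import jax.numpy as jnp


def sum_first_n_powers(x, n):
    # Hypothetical helper: n drives a Python loop, so it must be a
    # concrete int when JAX traces the function.
    total = jnp.zeros_like(x)
    for k in range(n):
        total = total + x ** k
    return total


# Without static_argnums, n arrives as a tracer and range(n) raises
# during tracing (the exact exception class varies by JAX version).
try:
    jax.jit(sum_first_n_powers)(jnp.array([2.0]), 3)
    traced_n_failed = False
except Exception:
    traced_n_failed = True

# With n marked static, it stays a plain Python int inside the function.
powers = jax.jit(sum_first_n_powers, static_argnums=(1,))
out = powers(jnp.array([2.0]), 3)  # 2**0 + 2**1 + 2**2 = 7.0
```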

static_argnums lets you tell jit: "treat these positional arguments as Python constants, not JAX tracers." jit then re-traces and recompiles whenever a static argument changes value, which is fine if the argument only ever takes a handful of distinct values (e.g., n_repeats=2, n_repeats=4). Because static values become part of the cache key, they must also be hashable.
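One way to observe the re-tracing behaviour described above is to count how often the Python body actually runs. The sketch below assumes a hypothetical wrapper, tile_and_count, that bumps a global counter; because Python side effects inside a jitted function execute only while JAX traces, the counter tracks traces rather than calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # bumped by a Python side effect that only runs while tracing


def tile_and_count(x, n_repeats):
    global trace_count
    trace_count += 1  # executes at trace time, not on cached calls
    return jnp.tile(x, n_repeats)


tile = jax.jit(tile_and_count, static_argnums=(1,))

x = jnp.array([1.0, 2.0, 3.0])
tile(x, 2)  # first time n_repeats=2 is seen: trace + compile
tile(x, 2)  # same static value and same x shape/dtype: cache hit
tile(x, 4)  # new static value: traces again
print(trace_count)  # 2
```

Relying on a global counter is only a debugging trick; it is not a reliable way to run side effects inside jitted code.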

Worked mini-example

import jax
import jax.numpy as jnp

# jnp.tile's second arg is a Python int controlling output shape.
# Without static_argnums this would crash with an abstract tracer error.
f = jax.jit(jnp.tile, static_argnums=(1,))

x = jnp.array([1.0, 2.0, 3.0])
result = f(x, 3)        # → [1., 2., 3., 1., 2., 3., 1., 2., 3.]
result2 = f(x, 2)       # re-traces and recompiles because the static value changed (3 → 2)

Common pitfalls

  • Indices are 0-based: static_argnums=(1,) marks the second positional argument (index 1), not the first.
  • Float inputs: the test harness passes n_repeats as a float; cast it with int(n_repeats) before passing it to jnp.tile.
  • Reordering args breaks static_argnums: if you later add a new first argument, index 1 now points to a different parameter. Prefer static_argnames (Problem 2) for that reason.
  • Every distinct static value triggers a re-trace: avoid passing user-supplied integers that span a large range, since each new value costs a fresh trace and compile and adds another entry to the jit cache.
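The reordering pitfall above can be sidestepped by naming the static parameter instead of indexing it. The sketch below uses a hypothetical function, tile_by_name, and jax.jit's static_argnames option; when only names are given, JAX inspects the function signature so the argument is treated as static whether it is passed by keyword or by position.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical variant: mark the argument static by *name*, so adding a new
# leading parameter later cannot silently shift which argument is static.
@partial(jax.jit, static_argnames=("n_repeats",))
def tile_by_name(x, n_repeats):
    return jnp.tile(x, n_repeats)


x = jnp.array([1.0, 2.0])
a = tile_by_name(x, n_repeats=2)  # keyword call
b = tile_by_name(x, 2)            # positional call: JAX maps it to the name
```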

Problem

Implement jit_with_static_int(x, n_repeats) that tiles the 1-D array x along axis 0 exactly n_repeats times using jnp.tile, compiled with jax.jit using static_argnums=(1,).

  • x: 1-D JAX array.
  • n_repeats: passed as a float from the test harness; cast to int.

Returns: 1-D array of shape (x.shape[0] * n_repeats,).

Example (not from the test set):

  • jit_with_static_int(jnp.array([1.0, 2.0]), 3) returns [1.0, 2.0, 1.0, 2.0, 1.0, 2.0].
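One possible sketch of a solution, assuming the harness calls the function positionally as in the example. The decorator form is equivalent to jax.jit(f, static_argnums=(1,)); the int(...) cast handles the float that the harness supplies, which is safe here because a static argument is an ordinary Python value inside the function.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def jit_with_static_int(x, n_repeats):
    # n_repeats is static, so it is a plain Python value here (possibly a
    # float such as 3.0); jnp.tile needs an integer repeat count.
    return jnp.tile(x, int(n_repeats))


out = jit_with_static_int(jnp.array([1.0, 2.0]), 3)
```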

Hints

jax jit static-argnums
