jit with static_argnames
Why this matters
static_argnums binds static-ness to positions in the argument list, so it can break when you refactor and reorder parameters. The index then points at a different argument, which may surface as a tracing error or, worse, silently mark the wrong argument static.
static_argnames is the safer, more explicit alternative: you name the argument by its keyword rather than its index. As long as the name stays the same, the static binding survives any reordering of the parameters.
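To make the reordering argument concrete, here is a small sketch. It assumes two hypothetical versions of the same function (scale_tile_v1 and scale_tile_v2) that differ only in where n sits in the signature. A single static_argnames=("n",) binding is correct for both.

```python
import jax
import jax.numpy as jnp

# Hypothetical "before refactor" signature: n is the third parameter.
def scale_tile_v1(x, factor, n):
    return jnp.tile(x * factor, n)

# Hypothetical "after refactor" signature: n has moved to second place.
def scale_tile_v2(x, n, factor):
    return jnp.tile(x * factor, n)

# The same name-based binding is correct for both versions.
f1 = jax.jit(scale_tile_v1, static_argnames=("n",))
f2 = jax.jit(scale_tile_v2, static_argnames=("n",))

# By contrast, static_argnums=(2,) is right for v1. If it were carried over to
# v2 unchanged, it would mark `factor` static and leave `n` traced. jnp.tile
# needs a concrete repeat count, so tracing v2 would then fail.

x = jnp.array([1.0, 2.0])
out1 = f1(x, factor=2.0, n=3)
out2 = f2(x, n=3, factor=2.0)
# out1 and out2 both hold [2., 4., 2., 4., 2., 4.]
```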
In real JAX codebases you'll commonly see static_argnames on training-step functions, where a flag such as train=True/False selects behaviour at trace time (each value gets its own compiled version) rather than at runtime.
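The train-flag pattern described above can be sketched as follows. The function apply_layer, its params dict and the fixed 0.5 scale are hypothetical stand-ins, not any library's API. The point is that a static flag lets you use a plain Python if inside a jitted function.

```python
import jax
import jax.numpy as jnp

def apply_layer(params, x, train):
    h = x @ params["w"]
    # `train` is static, so this is an ordinary Python `if` evaluated at trace
    # time. With a traced bool, JAX would raise a concretization error here.
    if train:
        # Hypothetical stand-in for a train-only op (e.g. dropout). A fixed
        # scale keeps the sketch deterministic.
        h = h * 0.5
    return h

apply_jit = jax.jit(apply_layer, static_argnames=("train",))

params = {"w": jnp.eye(2)}
x = jnp.array([[1.0, 2.0]])
y_train = apply_jit(params, x, train=True)   # compiled for train=True
y_eval = apply_jit(params, x, train=False)   # compiled separately for train=False
```

Because train is part of the compilation cache key, each flag value costs one compile up front. After that, there is no runtime branching cost.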
Worked mini-example
```python
import jax
import jax.numpy as jnp

def my_tile(x, n):
    return jnp.tile(x, n)

# Bind by name: safe against argument reordering.
f = jax.jit(my_tile, static_argnames=("n",))

x = jnp.array([1.0, 2.0])
result = f(x, n=3)   # → [1., 2., 1., 2., 1., 2.]
result2 = f(x, n=2)  # new static value for n, so JAX re-traces and recompiles
```
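One way to observe the re-tracing noted in the mini-example is a Python-side counter. This is a sketch that relies on the fact that Python side effects inside a jitted function run only while JAX traces it. The trace_count variable is purely illustrative.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def my_tile(x, n):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX is tracing
    return jnp.tile(x, n)

f = jax.jit(my_tile, static_argnames=("n",))
x = jnp.array([1.0, 2.0])

f(x, n=3)  # first call with n=3: trace and compile
f(x, n=3)  # same static value, same x shape/dtype: cache hit, no retrace
f(x, n=2)  # new static value: retrace
# trace_count is now 2
```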
Common pitfalls
- Name must match the wrapped function's parameter: if you wrap a lambda or a function whose argument is called reps, then static_argnames=("n",) won't work; you must use ("reps",).
- Case-sensitive: "N" ≠ "n".
- Float inputs: the test harness passes n as a float; cast it with int(n) before passing it to jnp.tile.
- Keyword or positional calls both work: f(x, n=3) and f(x, 3) both treat n as static. jax.jit resolves the name to a position from the function's signature when it wraps the function, not at call time.
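The pitfalls above can be checked with a short sketch. It assumes a lambda whose parameter is named reps (wrapped here as a hypothetical tile_reps) and simulates the harness's float input with the literal 2.0.

```python
import jax
import jax.numpy as jnp

# The lambda's parameter is called `reps`, so that is the name to list.
tile_reps = jax.jit(lambda x, reps: jnp.tile(x, reps), static_argnames=("reps",))

x = jnp.array([1.0, 2.0])
a = tile_reps(x, reps=2)  # keyword call
b = tile_reps(x, 2)       # positional call: jit maps position 1 to `reps` via the signature

n_from_harness = 2.0      # simulated float input, as the test harness would pass it
c = tile_reps(x, reps=int(n_from_harness))  # cast before it reaches jnp.tile
```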
Problem
Implement jit_with_named_static(x, n) that tiles x exactly n times
along axis 0 using jax.jit with static_argnames=("n",).
- x: 1-D JAX array.
- n: passed as a float from the test harness; cast to int.
Returns: 1-D array of shape (x.shape[0] * n,).
Example (not from the test set):
- jit_with_named_static(jnp.array([5.0, 6.0]), 2.0) returns [5.0, 6.0, 5.0, 6.0].
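Here is one possible shape for a solution, offered as a sketch rather than the reference answer. The helpers _tile_impl and _tile_jit are hypothetical names. The float-to-int cast happens in the outer wrapper, so the jitted function always receives an int for n.

```python
import jax
import jax.numpy as jnp

def _tile_impl(x, n):
    # n is static, so it is a plain Python int here and jnp.tile sees a
    # concrete repeat count along axis 0 for a 1-D x.
    return jnp.tile(x, n)

# Jit once at module level; compiled versions are cached per distinct n.
_tile_jit = jax.jit(_tile_impl, static_argnames=("n",))

def jit_with_named_static(x, n):
    # The harness passes n as a float (e.g. 2.0); cast so jnp.tile gets an int.
    return _tile_jit(x, n=int(n))
```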