
jit with static_argnames

Why this matters

static_argnums binds static-ness to positions in the argument list. When you refactor and reorder parameters, the index ends up pointing at a different argument. Sometimes this fails loudly. Sometimes it silently marks the wrong argument as static. static_argnames is the safer, more explicit alternative: you name the argument instead of giving its index. jax.jit looks that name up in the function's signature, so as long as the name stays the same, the static binding survives any reordering of the parameters.
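
A minimal sketch of the reordering argument, assuming two hypothetical versions of the same helper (tile_v1 and tile_v2, not from the original) where a refactor swapped the parameter order. The same static_argnames=("n",) works for both, whereas static_argnums=(1,) would mean n in one version and the array x in the other:

```python
import jax
import jax.numpy as jnp

# Hypothetical original signature: array first, static count second.
def tile_v1(x, n):
    return jnp.tile(x, n)

# Hypothetical refactored signature: the parameters were swapped.
def tile_v2(n, x):
    return jnp.tile(x, n)

# Binding by name resolves to index 1 for tile_v1 and index 0 for tile_v2.
# static_argnums=(1,) would instead mark the array x as static in tile_v2.
f1 = jax.jit(tile_v1, static_argnames=("n",))
f2 = jax.jit(tile_v2, static_argnames=("n",))

x = jnp.array([1.0, 2.0])
out1 = f1(x, 3)   # n passed positionally, still static
out2 = f2(3, x)   # n is now the first parameter, still static
print(out1, out2)
```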

In real JAX codebases you’ll often see static_argnames used with training-step functions, where one argument (e.g., train=True/False) switches behaviour at trace/compile time rather than at runtime. Each distinct value of the flag gets its own compiled version.
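
A sketch of that pattern under stated assumptions: forward is a hypothetical toy linear layer, and the train-only branch applies inverted dropout with a fixed PRNG key purely for illustration. Because train is static, the plain Python if below is resolved during tracing. If train were traced instead, that if would fail to convert a tracer to a bool:

```python
import jax
import jax.numpy as jnp

def forward(params, x, train):
    # Hypothetical toy layer: params is a weight matrix in this sketch.
    h = x @ params
    # train is static, so it arrives as a plain Python bool and this `if`
    # is ordinary Python control flow evaluated at trace time.
    if train:
        # Illustrative train-only step: inverted dropout (keep prob 0.5)
        # with a fixed key; a real training step would thread a fresh key.
        keep = jax.random.bernoulli(jax.random.PRNGKey(0), 0.5, h.shape)
        h = jnp.where(keep, h / 0.5, 0.0)
    return h

# train=True and train=False each get their own trace and compilation.
forward_jit = jax.jit(forward, static_argnames=("train",))

params = jnp.eye(3)
x = jnp.ones((2, 3))
out_eval = forward_jit(params, x, train=False)  # identity weights: all ones
out_train = forward_jit(params, x, train=True)  # each entry is 0.0 or 2.0
```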

Worked mini-example

import jax
import jax.numpy as jnp

def my_tile(x, n):
    return jnp.tile(x, n)

# Bind by name — safe against arg reordering.
f = jax.jit(my_tile, static_argnames=("n",))

x = jnp.array([1.0, 2.0])
result = f(x, n=3)    # → [1., 2., 1., 2., 1., 2.]
result2 = f(x, n=2)   # new static value of n: re-traces and recompiles
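
To see when re-tracing actually happens, one option is a Python side effect inside the function. The traces list below is an illustrative addition. It only grows while JAX traces my_tile, not on cached calls, and the recorded value shows that n is a plain Python int at trace time:

```python
import jax
import jax.numpy as jnp

traces = []  # appended to only while JAX traces the Python function

def my_tile(x, n):
    traces.append(n)  # side effect runs at trace time, not on cache hits
    return jnp.tile(x, n)

f = jax.jit(my_tile, static_argnames=("n",))
x = jnp.array([1.0, 2.0])

f(x, n=3)  # first call with n=3: traces and compiles
f(x, n=3)  # same n, same shape/dtype of x: reuses the cached compilation
f(x, n=2)  # new static value: traces again
print(traces)  # [3, 2]
```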

Common pitfalls

  • The name must match the wrapped function’s parameter: if you wrap a lambda or a function whose argument is called reps, then static_argnames=("n",) won’t work. You must use ("reps",).
  • Names are case-sensitive: "N" ≠ "n".
  • Float inputs: the test harness passes n as a float. Cast it with int(n) before passing it to jnp.tile, which expects integer repetition counts.
  • You don’t have to call with the keyword: f(x, n=3) and f(x, 3) both treat n as static. jax.jit resolves the name to a position from the function’s signature when it wraps the function, not at call time.
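
The pitfalls above can be checked together in one short sketch, assuming a hypothetical helper tile_reps whose count parameter is named reps. The static name matches the parameter exactly, the float is cast inside the function, and both keyword and positional calls treat reps as static:

```python
import jax
import jax.numpy as jnp

# Hypothetical helper whose repeat-count parameter is called `reps`, not `n`.
def tile_reps(x, reps):
    # reps is static, so it is a plain Python float here and int() is safe.
    return jnp.tile(x, int(reps))

# The string must match the parameter name exactly (case-sensitive).
g = jax.jit(tile_reps, static_argnames=("reps",))

x = jnp.array([1.0, 2.0])
by_keyword = g(x, reps=2.0)  # float, as a test harness might pass it
by_position = g(x, 2.0)      # positional call: reps is still static
print(by_keyword, by_position)
```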

Problem

Implement jit_with_named_static(x, n) that tiles x exactly n times along axis 0 using jax.jit with static_argnames=("n",).

  • x: 1-D JAX array.
  • n: passed as a float from the test harness; cast to int.

Returns: 1-D array of shape (x.shape[0] * int(n),).

Example (not from the test set):

  • jit_with_named_static(jnp.array([5.0, 6.0]), 2.0) returns [5.0, 6.0, 5.0, 6.0].
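
One possible sketch that meets the spec, not the site's reference solution. The inner helper _tile_along_axis0 and the module-level jitted wrapper _tiled are hypothetical names. Building the jitted wrapper once lets repeated calls reuse jit's compilation cache:

```python
import jax
import jax.numpy as jnp

def _tile_along_axis0(x, n):
    # n is static, so it arrives as a plain Python float; jnp.tile needs an int.
    return jnp.tile(x, int(n))

# Hypothetical module-level wrapper, created once so the cache is shared.
_tiled = jax.jit(_tile_along_axis0, static_argnames=("n",))

def jit_with_named_static(x, n):
    # For a 1-D x, jnp.tile(x, k) has shape (x.shape[0] * k,).
    return _tiled(x, n=n)
```

Calling jit_with_named_static(jnp.array([5.0, 6.0]), 2.0) should give the example's [5.0, 6.0, 5.0, 6.0].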

Hints

jax jit static-argnames
