jit with Pytree Input
Why this matters
jax.jit understands pytrees: arbitrary nested structures of JAX
arrays (lists, tuples, dicts, and any registered custom node). This is why
you can write jit(loss_fn)(params, batch), where params is a nested
dict of weight matrices and batch is a tuple of (x, y) tensors, and
jit handles all of it transparently.
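A minimal sketch of that jit(loss_fn)(params, batch) call, assuming a hypothetical one-layer linear model with a mean-squared-error loss. The names loss_fn, params, and batch are illustrative, not taken from any library:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters: a nested dict of arrays (a pytree).
params = {"dense": {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}}

# Hypothetical batch: a tuple (x, y) of arrays (also a pytree).
batch = (jnp.ones((5, 3)), jnp.zeros((5, 2)))

def loss_fn(params, batch):
    x, y = batch
    pred = x @ params["dense"]["w"] + params["dense"]["b"]
    return jnp.mean((pred - y) ** 2)

# jit accepts both nested arguments as-is; no manual flattening is needed.
loss = jax.jit(loss_fn)(params, batch)
print(loss)  # 9.0
```

Every prediction is 3.0 (three ones times a weight of one) against a target of 0.0, so the mean squared error is 9.0.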
jit flattens the pytree into a list of leaves (the actual JAX arrays), traces the function on abstract versions of those leaves (their shapes and dtypes), and re-assembles the output pytree. The structure itself (the treedef) becomes part of the jit cache key together with each leaf's shape and dtype, so a different structure triggers a re-trace.
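The flatten and re-assemble steps can be seen directly with jax.tree_util. This sketch uses the parameter dict from the worked example below. The helper name double is hypothetical and only illustrates that a jitted function returning a pytree gets its output rebuilt with the original structure:

```python
import jax
import jax.numpy as jnp

params = {"w1": jnp.ones((4, 4)), "b1": jnp.zeros((4,))}

# Flatten: the leaves are the arrays; the treedef records the container structure.
leaves, treedef = jax.tree_util.tree_flatten(params)
print(len(leaves))  # 2
print(treedef)      # the structure only, with leaves shown as placeholders

# Unflatten rebuilds the same structure from a list of leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

@jax.jit
def double(tree):
    # While tracing, each leaf is an abstract value carrying only shape and dtype.
    return jax.tree_util.tree_map(lambda leaf: 2.0 * leaf, tree)

out = double(params)  # the output comes back as a dict with the same keys
print(jax.tree_util.tree_structure(out) == treedef)  # True
```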
Understanding this unlocks real ML code: Optax optimiser states, Flax parameter dicts, and Haiku parameters are all pytrees that jit handles natively.
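As a hedged illustration that uses only JAX itself (not the Optax, Flax, or Haiku APIs), a hand-written SGD step can update an entire parameter pytree inside jit. The names sgd_step and LEARNING_RATE and the tiny linear model are hypothetical:

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # hypothetical step size

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # grads has exactly the same pytree structure as params.
    grads = jax.grad(loss_fn)(params, x, y)
    # Update every leaf of the parameter pytree in a single tree_map call.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

params = {"w": jnp.zeros((2, 1)), "b": jnp.zeros((1,))}
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[1.0], [2.0]])

new_params = sgd_step(params, x, y)
print(loss_fn(params, x, y), loss_fn(new_params, x, y))  # the loss decreases
```

Real optimiser libraries keep extra state (such as momentum buffers) in the same kind of nested structure, which is why jit can compile their update functions unchanged.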
Worked mini-example
```python
import jax
import jax.numpy as jnp

# params is a dict of arrays, which is a pytree.
params = {
    "w1": jnp.ones((4, 4)),
    "b1": jnp.zeros((4,)),
}

@jax.jit
def forward(params, x):
    return params["w1"] @ x + params["b1"]

x = jnp.ones((4,))
y = forward(params, x)  # jit flattens the dict into its leaves and traces once.
```
Common pitfalls
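- A structure change causes a cache miss: adding or removing a dict key, changing a list's length, swapping a tuple for a list, or passing a different registered pytree type produces a new treedef, and jit re-traces. Key order alone does not matter for plain dicts, because JAX sorts dict keys when flattening.
- Leaves must be array-like: Python scalars inside the pytree are converted to traced (weakly typed) arrays, so changing their values does not re-trace, while a non-array leaf such as a string raises a TypeError. Values that must stay Python constants belong in a separate argument marked with static_argnums or static_argnames; changing a static argument does trigger a re-trace.
- This problem uses a flat array: because test inputs arrive as JSON, the harness packs them into a single flat array rather than a nested dict of JAX arrays. The principle (jit handles pytrees) applies identically to nested structures.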
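The pitfalls above can be checked with a small sketch. The counters in trace_counts are a hypothetical helper: a Python side effect inside a jitted function runs only while jit traces, so each counter records how many times its function was (re-)traced. The function names total, scaled, and scale_by_mode are illustrative:

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_counts = {"total": 0, "scaled": 0}

@jax.jit
def total(tree):
    trace_counts["total"] += 1  # runs only while tracing
    return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(tree))

w, b = jnp.ones((2,)), jnp.ones((3,))
total({"w": w, "b": b})          # first call: traces
total({"b": b, "w": w})          # dict keys are sorted on flatten: cache hit
total({"w": w, "b": b, "c": b})  # extra key gives a new treedef: re-trace

@jax.jit
def scaled(tree):
    trace_counts["scaled"] += 1
    return tree["x"] * tree["scale"]

scaled({"x": w, "scale": 2.0})
scaled({"x": w, "scale": 3.0})   # Python float leaf is traced, not static: cache hit
print(trace_counts)  # {'total': 2, 'scaled': 1}

try:
    jax.jit(lambda t: t["scale"])({"scale": "big"})  # a string is not a valid leaf
    string_leaf_error = None
except TypeError as err:
    string_leaf_error = err
print(type(string_leaf_error).__name__)  # TypeError

# Python constants that must stay static go in a separate static argument.
@partial(jax.jit, static_argnames="mode")
def scale_by_mode(x, mode):
    return x * 2.0 if mode == "double" else x

print(scale_by_mode(w, mode="double"))  # [2. 2.]
```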
Problem
Implement jit_pytree_sum(packed_inputs), which sums all elements of the
1-D input array using jax.jit(jnp.sum). A bare array is the simplest pytree
(a single leaf), so this exercises the same jit machinery that handles
pytree-structured inputs.
- packed_inputs: 1-D JAX array.
- Returns: scalar, the sum of all elements.
Example (not from the test set):
- jit_pytree_sum(jnp.array([2.0, 3.0, 5.0])) returns 10.0.
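One possible sketch, assuming the harness calls jit_pytree_sum directly with a 1-D JAX array. Wrapping jnp.sum once at module level lets repeated calls with the same shape and dtype reuse the compiled version. The nested_pytree_sum function is a hypothetical extension showing the same idea over an arbitrary nested structure, by flattening it into leaves first:

```python
import jax
import jax.numpy as jnp

# Compile jnp.sum once so repeated calls reuse jit's cache.
_jitted_sum = jax.jit(jnp.sum)

def jit_pytree_sum(packed_inputs):
    """Sum all elements of a 1-D JAX array with a jit-compiled jnp.sum."""
    return _jitted_sum(packed_inputs)

# Hypothetical nested variant: sum every element of every leaf in a pytree.
@jax.jit
def nested_pytree_sum(tree):
    leaves = jax.tree_util.tree_leaves(tree)
    return sum(jnp.sum(leaf) for leaf in leaves)

print(jit_pytree_sum(jnp.array([2.0, 3.0, 5.0])))  # 10.0
print(nested_pytree_sum({"a": jnp.array([2.0, 3.0]), "b": (jnp.array(5.0),)}))  # 10.0
```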