easy primitives

jit with Pytree Input

Why this matters

jax.jit understands pytrees: arbitrarily nested containers (lists, tuples, dicts, and any registered custom node) whose leaves are JAX arrays. This is why you can write jit(loss_fn)(params, batch), where params is a nested dict of weight matrices and batch is a tuple of (x, y) tensors, and jit handles all of it transparently.
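A minimal sketch of the jit(loss_fn)(params, batch) pattern, assuming a hypothetical two-layer model with a squared-error loss (the layer names, shapes, and loss are illustrative, not taken from any library). It also shows that jax.grad returns gradients with the same pytree structure as params.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-layer parameters: a nested dict, i.e. a pytree with four leaves.
params = {
    "layer1": {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))},
    "layer2": {"w": jnp.ones((2, 1)), "b": jnp.zeros((1,))},
}
# batch is a tuple (x, y), another pytree.
batch = (jnp.ones((5, 3)), jnp.zeros((5, 1)))

def loss_fn(params, batch):
    x, y = batch
    h = jnp.tanh(x @ params["layer1"]["w"] + params["layer1"]["b"])
    pred = h @ params["layer2"]["w"] + params["layer2"]["b"]
    return jnp.mean((pred - y) ** 2)  # illustrative squared-error loss

# jit accepts both nested arguments directly; no manual flattening is needed.
loss = jax.jit(loss_fn)(params, batch)

# Gradients come back as a pytree mirroring the structure of params.
grads = jax.jit(jax.grad(loss_fn))(params, batch)
```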

jit flattens the pytree into a list of leaves (the actual JAX arrays), traces the function on abstract versions of those leaves, and reassembles the output pytree. The tree structure (the "treedef") becomes part of jit's cache key, together with each leaf's shape and dtype, so calling with a different structure triggers a re-trace.
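The flatten, trace, and reassemble cycle can be observed directly with jax.tree_util.tree_flatten and tree_unflatten. The sketch below also counts traces through a Python side effect (which runs only while jit is tracing) to show that a new structure causes a re-trace; the function total and the list trace_log are illustrative names.

```python
import jax
import jax.numpy as jnp

tree = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,)), "extra": (jnp.arange(3.0),)}

# Flatten: a list of leaf arrays plus a treedef describing the containers.
leaves, treedef = jax.tree_util.tree_flatten(tree)
# Unflatten: rebuild a pytree of the same structure from those leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

trace_log = []  # illustrative counter: appended to only while jit is tracing

@jax.jit
def total(tree):
    trace_log.append(1)  # Python side effect, so it runs at trace time only
    return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(tree))

r1 = total(tree)                     # first call: traces
r2 = total(tree)                     # same structure, shapes, dtypes: cache hit
r3 = total({"w": jnp.ones((2, 2))})  # different structure: re-traces
```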

Understanding this unlocks real ML code: Optax optimiser states, Flax parameter dicts, and the parameter mappings returned by a Haiku transform's init are all pytrees that jit handles natively.

Worked mini-example

import jax
import jax.numpy as jnp

# params is a dict of arrays, i.e. a pytree with two leaves.
params = {
    "w1": jnp.ones((4, 4)),
    "b1": jnp.zeros((4,)),
}

@jax.jit
def forward(params, x):
    return params["w1"] @ x + params["b1"]

x = jnp.ones((4,))
y = forward(params, x)   # jit flattens the dict into its leaves and traces once; y is [4., 4., 4., 4.]

Common pitfalls

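  • Structure change → cache miss: if the container structure differs between calls (a key added or removed, a list passed where a tuple was before, a different registered pytree type), jit re-traces. Dict key order alone does not matter, because JAX flattens dicts in sorted key order.
  • Non-array leaves: Python scalars inside the pytree are converted to weakly typed arrays and traced like any other leaf, so changing their value does not re-trace. Leaves that are not valid JAX types, such as strings, raise a TypeError; pass such values as a separate argument marked static with static_argnums or static_argnames, in which case changing the value does trigger a re-trace.
  • This problem uses a flat array: the test harness packs all inputs into a single 1-D array rather than a nested dict. The principle (jit handles pytrees) applies identically to nested structures.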
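A short sketch of the pitfalls above, reusing the trace-counting trick (a Python list appended to during tracing); scaled_sum and add_pair are hypothetical helpers. It checks that a Python float leaf is traced rather than static, that swapping a list for a tuple re-traces, and that a string leaf is rejected.

```python
import jax
import jax.numpy as jnp

scalar_traces = []  # appended to only while jit traces scaled_sum

@jax.jit
def scaled_sum(tree):
    scalar_traces.append(1)
    return jnp.sum(tree["w"]) * tree["scale"]

w = jnp.ones((2,))
a = scaled_sum({"w": w, "scale": 2.0})  # traces; the float leaf becomes a traced scalar
b = scaled_sum({"w": w, "scale": 3.0})  # new value, same shape/dtype: no re-trace
n_scalar_traces = len(scalar_traces)

pair_traces = []

@jax.jit
def add_pair(pair):
    pair_traces.append(1)
    return pair[0] + pair[1]

add_pair([w, w])  # list node: traces
add_pair((w, w))  # tuple node is a different structure: re-traces

try:
    scaled_sum({"w": w, "scale": "big"})  # a str is not a valid JAX type
    string_rejected = False
except TypeError:
    string_rejected = True
```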

Problem

Implement jit_pytree_sum(packed_inputs), which sums all elements of the 1-D input array using jax.jit(jnp.sum). A single array is itself a pytree with exactly one leaf, so this is the simplest case of jit over pytree-structured input.

  • packed_inputs: 1-D JAX array.

Returns: scalar, the sum of all elements.

Example (not from the test set):

  • jit_pytree_sum(jnp.array([2.0, 3.0, 5.0])) returns 10.0.
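One possible solution sketch, following the problem statement's suggestion of jax.jit(jnp.sum). The module-level name _jitted_sum is an illustrative choice so the jitted function is built once and reused across calls.

```python
import jax
import jax.numpy as jnp

# Build the jitted function once so repeated calls hit the same compilation cache.
_jitted_sum = jax.jit(jnp.sum)

def jit_pytree_sum(packed_inputs):
    """Sum every element of a 1-D JAX array with a jitted jnp.sum."""
    return _jitted_sum(packed_inputs)

x = jnp.array([2.0, 3.0, 5.0])
result = jit_pytree_sum(x)

# A bare array is a pytree whose only leaf is the array itself.
n_leaves = len(jax.tree_util.tree_leaves(x))
```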

Hints

jax jit pytree
