Pytree Leaves
Why this matters
A pytree in JAX is any nested Python container (dicts, lists, tuples, or
None) whose innermost values, the “leaves”, are arrays (or scalars).
Examples:
- {"layer1": {"w": jnp.zeros((3, 4)), "b": jnp.zeros((4,))}, "layer2": {...}}
- [(jnp.zeros((2,)), jnp.ones((3,))), {"extra": jnp.zeros((1,))}]
- jnp.zeros((5,)) (a single leaf is a degenerate pytree)
JAX’s transformations operate on pytrees natively. jax.grad returns
gradients with the same pytree structure as your params. jax.vmap maps
over an axis of every leaf, and jax.tree_util.tree_map applies a function
to each leaf. Optimizer state is typically a pytree mirroring your params.
Libraries across the JAX ecosystem, such as Optax, Flax, and Equinox, are
built around this abstraction.
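To make the structure-matching claim concrete, here is a minimal sketch. The toy linear model, the loss function, and the learning rate lr are illustrative assumptions and do not come from any particular library.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-parameter model, used only to illustrate structure matching.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}

def loss(p, x):
    # Squared output of a linear model: (w . x + b)^2
    return (jnp.dot(p["w"], x) + p["b"]) ** 2

x = jnp.array([3.0, 4.0])
grads = jax.grad(loss)(params, x)

# grads is a dict with the same keys and leaf shapes as params.
same_structure = (
    jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)
)

# A plain gradient-descent step applied leaf by leaf; lr is an arbitrary value.
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```

Because grads mirrors params, the update can be written once with tree_map instead of once per parameter.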
Manipulating pytrees uses jax.tree_util:
- tree_leaves(pytree) → flat list of leaves.
- tree_map(fn, pytree) → apply fn to each leaf, preserve structure.
- tree_reduce(fn, pytree, init) → fold over leaves.
- tree_structure(pytree) → captures the structure for later reconstruction via tree_unflatten.
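The four utilities above can be tried together in a short sketch. The example tree and the doubling function are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp

# Arbitrary example tree: a dict holding an array and a tuple of arrays.
tree = {"a": jnp.array([1.0, 2.0]), "b": (jnp.array(3.0), jnp.array([4.0]))}

# Flat list of the three array leaves.
leaves = jax.tree_util.tree_leaves(tree)

# Structure only (no values), kept for later reconstruction.
treedef = jax.tree_util.tree_structure(tree)

# Apply a function to every leaf while preserving the container layout.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, tree)

# Fold over the leaves: the accumulator starts at 0.0 and adds each leaf's sum.
total = jax.tree_util.tree_reduce(lambda acc, x: acc + jnp.sum(x), tree, 0.0)

# Rebuild a tree with the saved structure from a transformed list of leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, [2 * x for x in leaves])
```

Here rebuilt and doubled hold the same values: flattening, transforming the list, and unflattening is the manual form of what tree_map does in one call.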
Worked mini-example
import operator

import jax
import jax.numpy as jnp

params = {
    "layer1": {"w": jnp.array([[1., 2.], [3., 4.]]), "b": jnp.array([5., 6.])},
    "layer2": {"w": jnp.array([[7., 8.]]), "b": jnp.array([9.])},
}

# Dict keys are traversed in sorted order, so "b" comes before "w".
leaves = jax.tree_util.tree_leaves(params)
# [array([5,6]), array([[1,2],[3,4]]), array([9]), array([[7,8]])]

total = sum(jnp.sum(l) for l in leaves)
# = (5+6) + (1+2+3+4) + (9) + (7+8) = 11 + 10 + 9 + 15 = 45

# Equivalent via tree_reduce:
total2 = jax.tree_util.tree_reduce(
    operator.add,
    jax.tree_util.tree_map(jnp.sum, params),
    0.0,
)
Common pitfalls
- Mistaking a single array for “no pytree”: a lone jnp.array([1, 2, 3]) is a pytree with ONE leaf; it works with all tree utilities.
- None as a leaf vs structure: None is treated as an empty structure node (a node with no children) by default, NOT as a leaf. This matters when masking optimizer steps.
- Custom classes as nodes: by default, custom Python classes are LEAVES (opaque). Register them with jax.tree_util.register_pytree_node to make their attributes traversable.
- tree_map structure requirement: when mapping over multiple pytrees together (tree_map(fn, t1, t2)), every additional tree must match the structure of the first (or have it as a prefix).
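The pitfalls above can be checked directly. This is a sketch: the Pair class is a hypothetical container invented for the example, and the exact exception type raised on a structure mismatch may vary between JAX versions, so the code catches both likely candidates.

```python
import jax
import jax.numpy as jnp

# 1. A lone array is a pytree with exactly one leaf.
single = jax.tree_util.tree_leaves(jnp.array([1, 2, 3]))

# 2. None contributes no leaves by default; is_leaf can opt in.
with_none = {"a": jnp.array(1.0), "b": None}
default_leaves = jax.tree_util.tree_leaves(with_none)
explicit_leaves = jax.tree_util.tree_leaves(
    with_none, is_leaf=lambda x: x is None
)

# 3. An unregistered class is an opaque leaf.
class Pair:  # hypothetical container, not part of JAX
    def __init__(self, left, right):
        self.left = left
        self.right = right

opaque = jax.tree_util.tree_leaves(Pair(jnp.zeros(2), jnp.ones(2)))

# Registering the class exposes its fields as children.
jax.tree_util.register_pytree_node(
    Pair,
    lambda p: ((p.left, p.right), None),     # flatten -> (children, aux_data)
    lambda aux, children: Pair(*children),   # unflatten from children
)
opened = jax.tree_util.tree_leaves(Pair(jnp.zeros(2), jnp.ones(2)))

# 4. Mapping over trees whose structures differ raises an error.
try:
    jax.tree_util.tree_map(lambda x, y: x + y, {"a": 1.0}, {"b": 1.0})
    mismatch_raised = False
except (ValueError, TypeError):
    mismatch_raised = True
```

After this runs, single has one leaf, default_leaves has one and explicit_leaves has two, and the Pair instance goes from one opaque leaf to two array leaves once registered.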
Problem
Implement sum_pytree_leaves(leaves). The function is given the leaves of a
pytree as a flat 1-D array (we skip the extraction step to keep the test
contract simple). Return the sum of all elements.
In real code you’d call jax.tree_util.tree_leaves(pytree) first to extract
the flat list of leaves, then sum them. Here the pre-flattened array stands
in for that output.
One illustrative example (not from the test set):
- sum_pytree_leaves(jnp.array([1.0, 2.0, 3.0])) → 6.0
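One possible sketch of the function, alongside the "real code" variant described above. The helper name sum_pytree is a hypothetical addition for illustration and is not part of the problem's test contract.

```python
import jax
import jax.numpy as jnp

def sum_pytree_leaves(leaves):
    # leaves: flat 1-D array standing in for the already-flattened pytree.
    return jnp.sum(leaves)

def sum_pytree(pytree):
    # Hypothetical helper showing the full pipeline:
    # extract the leaves first, then sum the elements of each leaf.
    return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(pytree))
```

For example, sum_pytree({"w": jnp.array([[1., 2.], [3., 4.]]), "b": jnp.array([5., 6.])}) adds 10 and 11 to give 21, while sum_pytree_leaves only ever sees the flat array.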