easy primitives

Pytree Leaves

Why this matters

A pytree in JAX is any nested structure of Python containers (dicts, lists, tuples, and None) whose innermost values, the “leaves”, are arrays or scalars. Examples:

  • {"layer1": {"w": jnp.zeros((3, 4)), "b": jnp.zeros((4,))}, "layer2": {...}}
  • [(jnp.zeros((2,)), jnp.ones((3,))), {"extra": jnp.zeros((1,))}]
  • jnp.zeros((5,)) (a single leaf is a degenerate pytree)
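To make “what counts as a leaf” concrete, here is a minimal sketch that flattens the second and third examples above with jax.tree_util.tree_leaves and counts the results. The first example is skipped because its "layer2" entry is elided. The dict of Python scalars at the end is an extra illustration, not from the list.

```python
import jax
import jax.numpy as jnp

# The second and third examples from the list above.
nested = [(jnp.zeros((2,)), jnp.ones((3,))), {"extra": jnp.zeros((1,))}]
single = jnp.zeros((5,))

# Lists, tuples and dicts are interior nodes; only the arrays are leaves.
n_nested = len(jax.tree_util.tree_leaves(nested))  # 3 leaves
n_single = len(jax.tree_util.tree_leaves(single))  # 1 leaf (degenerate pytree)

# Plain Python scalars are leaves too.
n_scalars = len(jax.tree_util.tree_leaves({"lr": 0.1, "step": 3}))  # 2 leaves
```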

JAX’s transformations operate on pytrees natively. jax.grad returns gradients with the same pytree structure as your params. jax.vmap accepts pytree arguments and maps over a batch axis of every leaf, while jax.tree_util.tree_map applies a function to each leaf. Optimizer state is itself a pytree, typically containing subtrees that mirror your params. The wider JAX ecosystem, including Optax, Flax, and Equinox, is built around this abstraction.
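The two transformation claims above can be checked directly. The sketch below uses a hypothetical toy model (the names loss and predict are illustrative only). It shows that jax.grad returns a dict with the same tree structure as the params dict, and that jax.vmap with in_axes=0 for a pytree argument batches over axis 0 of every leaf in that pytree.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
x = jnp.array([3.0, 4.0])

def loss(p, x):
    # Hypothetical toy loss: squared linear prediction.
    return (jnp.dot(p["w"], x) + p["b"]) ** 2

# Gradient w.r.t. the first argument: a dict with keys "w" and "b".
grads = jax.grad(loss)(params, x)
same_structure = (
    jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)
)

def predict(p, x):
    # Hypothetical linear model for one set of params.
    return jnp.dot(p["w"], x) + p["b"]

# Four stacked parameter sets: every leaf carries a leading batch axis of size 4.
batched_params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((4,))}

# in_axes=0 for the params pytree applies to all of its leaves; x is shared.
preds = jax.vmap(predict, in_axes=(0, None))(batched_params, x)
```

Here grads["w"] is [69., 92.] and grads["b"] is 23., since d/dp of (w·x + b)² is 2(w·x + b) times x or 1, with w·x + b = 11.5.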

The jax.tree_util module provides the core utilities for manipulating pytrees:

  • tree_leaves(pytree) → flat list of leaves.
  • tree_map(fn, pytree) → apply fn to each leaf, preserve structure.
  • tree_reduce(fn, pytree, init) → fold over leaves.
  • tree_structure(pytree) → captures the structure for later reconstruction via tree_unflatten.
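A short sketch exercising the four utilities above on a small hypothetical tree. The reducer passed to tree_reduce receives (accumulator, leaf). jax.tree_util.tree_flatten, not listed above, returns the leaves and the structure in one call.

```python
import jax
import jax.numpy as jnp

tree = {"a": jnp.array([1.0, 2.0]), "b": (jnp.array(3.0), jnp.array([4.0]))}

# Flat list of leaves: [a, b[0], b[1]].
leaves = jax.tree_util.tree_leaves(tree)

# Same structure, every leaf doubled.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, tree)

# Fold over leaves: count the total number of array elements (2 + 1 + 1).
n_elements = jax.tree_util.tree_reduce(lambda acc, x: acc + x.size, tree, 0)

# Capture the structure, then rebuild a tree from new leaves in the same order.
treedef = jax.tree_util.tree_structure(tree)
rebuilt = jax.tree_util.tree_unflatten(treedef, [leaf * 10 for leaf in leaves])
```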

Worked mini-example

import operator

import jax
import jax.numpy as jnp

params = {
    "layer1": {"w": jnp.array([[1., 2.], [3., 4.]]), "b": jnp.array([5., 6.])},
    "layer2": {"w": jnp.array([[7., 8.]]),            "b": jnp.array([9.])},
}

# Dict keys are flattened in sorted order, so "b" comes before "w" in each layer.
leaves = jax.tree_util.tree_leaves(params)
# [Array([5., 6.]), Array([[1., 2.], [3., 4.]]), Array([9.]), Array([[7., 8.]])]

# Sum each leaf, then add up the per-leaf totals.
total = sum(jnp.sum(leaf) for leaf in leaves)
# = (5+6) + (1+2+3+4) + (9) + (7+8) = 11 + 10 + 9 + 15 = 45.0

# Equivalent via tree_reduce: reduce each leaf to a scalar with tree_map,
# then fold the scalars together with addition, starting from 0.0.
total2 = jax.tree_util.tree_reduce(
    operator.add,
    jax.tree_util.tree_map(jnp.sum, params),
    0.0,
)
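Because dicts are flattened in sorted key order, the leaf order can be surprising. One way to check it, assuming a JAX version recent enough to include jax.tree_util.tree_flatten_with_path and keystr, is to print each leaf's key path next to its shape:

```python
import jax
import jax.numpy as jnp

params = {
    "layer1": {"w": jnp.array([[1., 2.], [3., 4.]]), "b": jnp.array([5., 6.])},
    "layer2": {"w": jnp.array([[7., 8.]]), "b": jnp.array([9.])},
}

# Each entry is (key path, leaf); keystr renders the path as a readable string.
path_leaves, _treedef = jax.tree_util.tree_flatten_with_path(params)
for path, leaf in path_leaves:
    print(jax.tree_util.keystr(path), leaf.shape)

leaf_shapes = [leaf.shape for _, leaf in path_leaves]
```

The first path printed belongs to layer1's "b" entry, not its "w" entry, matching the leaf order in the worked example.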

Common pitfalls

  • Mistaking a single array for “no pytree”: a lone jnp.array([1, 2, 3]) is a pytree with ONE leaf; it works with all tree utilities.
  • None as a leaf vs structure: by default None is treated as an empty structure node with no children, NOT as a leaf, so tree_leaves skips it and tree_map passes it through unchanged. This matters when masking optimizer steps.
  • Custom classes as nodes: by default, custom Python classes are LEAVES (opaque). Register them with jax.tree_util.register_pytree_node to make their attributes traversable.
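  • tree_map structure requirement: when mapping over multiple pytrees together (tree_map(fn, t1, t2)), every extra pytree must have the same structure as the first (strictly, the first tree’s structure must be a prefix of each of the others). Mismatched structures raise an error.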
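The pitfalls above can be reproduced in a few lines. This is a minimal sketch: Point is a hypothetical class used only for illustration, and the registration mutates JAX's global pytree registry, so running the registration twice in one process raises a duplicate-registration error.

```python
import jax
import jax.numpy as jnp

# None is an empty node, not a leaf.
tree_with_none = {"a": jnp.ones(2), "b": None}
n_leaves_with_none = len(jax.tree_util.tree_leaves(tree_with_none))  # 1
mapped = jax.tree_util.tree_map(lambda x: 2 * x, tree_with_none)     # "b" stays None

# An unregistered class is a single opaque leaf.
class Point:
    # Hypothetical class, for illustration only.
    def __init__(self, x, y):
        self.x = x
        self.y = y

n_before = len(jax.tree_util.tree_leaves(Point(jnp.ones(2), jnp.zeros(2))))  # 1

jax.tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),             # flatten -> (children, aux_data)
    lambda _aux, children: Point(*children),  # unflatten(aux_data, children)
)
n_after = len(jax.tree_util.tree_leaves(Point(jnp.ones(2), jnp.zeros(2))))  # 2

# tree_map over pytrees with different dict keys fails.
try:
    jax.tree_util.tree_map(lambda u, v: u + v, {"a": 1.0}, {"b": 2.0})
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```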

Problem

Implement sum_pytree_leaves(leaves). The function is given the leaves of a pytree as a flat 1-D array (we skip the extraction step to keep the test contract simple). Return the sum of all elements.

In real code you’d call jax.tree_util.tree_leaves(pytree) first to extract the flat list of leaves, then sum them. Here the pre-flattened array stands in for that output.
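A minimal sketch of both halves, not the site's reference solution. Since the input is already one flat 1-D array, a single reduction suffices. For the “real code” path, the sketch assumes jax.flatten_util.ravel_pytree as the flattening step; it ravels and concatenates every leaf into one 1-D array, which is the shape of input this problem expects. The problem statement itself only mentions tree_leaves.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def sum_pytree_leaves(leaves: jax.Array) -> jax.Array:
    """Sum every element of a pre-flattened 1-D array of leaf values."""
    return jnp.sum(leaves)

# Real-code path: start from a pytree and flatten it first.
params = {
    "layer1": {"w": jnp.array([[1., 2.], [3., 4.]]), "b": jnp.array([5., 6.])},
    "layer2": {"w": jnp.array([[7., 8.]]), "b": jnp.array([9.])},
}
# ravel_pytree returns the flat 1-D array plus a function that undoes the flattening.
flat, unravel = ravel_pytree(params)
total = sum_pytree_leaves(flat)  # 45.0, matching the worked example
```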

One illustrative example (not from the test set):

  • sum_pytree_leaves(jnp.array([1.0, 2.0, 3.0])) → 6.0

Hints

jax pytree tree-util
