medium primitives

register_pytree_node

Why this matters

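JAX's transforms (jit, grad, vmap) operate on pytrees: nested containers whose leaves are typically arrays. By default, JAX knows about Python lists, tuples, dicts, None, and namedtuples. To use any other class of your own (a dataclass, a custom node type) with these transforms, you must register it as a pytree node. The example below registers a NamedTuple subclass explicitly to show the mechanics; the same call works for a plain class.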
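As a quick illustration of the default behaviour, the sketch below flattens a structure built only from built-in containers. The variable names (tree, leaves, doubled) are made up for this example; no registration is involved.

```python
import jax
import jax.numpy as jnp

# A nested structure made only of containers JAX handles by default.
tree = {
    "weights": [jnp.array([1., 2.]), jnp.array([3.])],
    "bias": (jnp.array(4.),),
}

# Flatten to a list of leaves. Dict entries are visited in sorted key order,
# so the "bias" leaf comes before the two "weights" leaves.
leaves = jax.tree_util.tree_leaves(tree)

# tree_map applies a function to every leaf and rebuilds the same structure.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, tree)
```

Here leaves holds three arrays and doubled has the same dict/list/tuple layout as tree.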

import jax
import jax.numpy as jnp
from typing import NamedTuple

class Pair(NamedTuple):
    first: jax.Array
    second: jax.Array

# Tell JAX how to flatten and unflatten a Pair:
jax.tree_util.register_pytree_node(
    Pair,
    flatten_func=lambda p: ([p.first, p.second], None),   # (children, aux)
    unflatten_func=lambda aux, children: Pair(*children),
)

# Now JAX treats Pair as a pytree:
p = Pair(jnp.array([1., 2.]), jnp.array([3., 4.]))
leaves = jax.tree_util.tree_leaves(p)
# → [Array([1., 2.], dtype=float32), Array([3., 4.], dtype=float32)]

Once registered, jit, grad, and vmap can traverse Pair like any built-in container. Libraries in the JAX ecosystem, such as Flax, rely on pytree registration for their own container types.
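To show what traversal by a transform looks like in practice, here is a minimal sketch that differentiates and jit-compiles a function whose argument is a registered container. The Linear class and the loss function are hypothetical, introduced only for this illustration.

```python
import jax
import jax.numpy as jnp

class Linear:
    """Hypothetical container with two array fields (not from the text above)."""
    def __init__(self, w, b):
        self.w = w
        self.b = b

jax.tree_util.register_pytree_node(
    Linear,
    lambda m: ((m.w, m.b), None),             # flatten: (children, aux_data)
    lambda aux, children: Linear(*children),  # unflatten: (aux_data, children)
)

def loss(model, x):
    # Toy scalar loss so that grad has something to differentiate.
    return jnp.sum(model.w * x + model.b)

model = Linear(jnp.array([1., 2.]), jnp.array([0.5, 0.5]))
x = jnp.array([3., 4.])

# grad differentiates with respect to the first argument and returns a
# Linear whose fields hold the gradients for w and b.
grads = jax.grad(loss)(model, x)

# jit accepts the container as an ordinary argument.
loss_value = jax.jit(loss)(model, x)
```

The gradient comes back in the same container type as the input, which is what lets optimizer code treat parameters and gradients uniformly.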

Worked mini-example

import jax
import jax.numpy as jnp
from typing import NamedTuple

class Pair(NamedTuple):
    first: jax.Array
    second: jax.Array

jax.tree_util.register_pytree_node(
    Pair,
    lambda p: ([p.first, p.second], None),
    lambda aux, ch: Pair(*ch),
)

p = Pair(jnp.array([1., 2., 3.]), jnp.array([4., 5., 6.]))
total = sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(p))
# total → 21.0

Common pitfalls

  • flatten_func must return (children, aux_data): children is an iterable of the sub-pytrees to traverse (typically the array fields); aux_data is static, non-array metadata (can be None).
  • unflatten_func takes (aux_data, children): note the order is reversed relative to what flatten_func returns.
  • aux_data must be hashable and comparable with ==: it becomes part of the treedef, which JAX hashes and compares, for example when looking up jit's compilation cache.
  • Register before first use: an unregistered custom class (other than a namedtuple) is treated as a single opaque leaf by tree_leaves and by every transform.
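The pitfalls above can be seen in one short sketch. The Tagged class and its label field are hypothetical; the label plays the role of hashable aux_data.

```python
import jax
import jax.numpy as jnp

class Tagged:
    """Hypothetical container: one array field plus a static string label."""
    def __init__(self, values, label):
        self.values = values
        self.label = label

t = Tagged(jnp.array([1., 2.]), "demo")

# Before registration, the whole object counts as one opaque leaf.
leaves_before = jax.tree_util.tree_leaves(t)

jax.tree_util.register_pytree_node(
    Tagged,
    # flatten returns (children, aux_data); the string label is hashable.
    lambda obj: ((obj.values,), obj.label),
    # unflatten receives (aux_data, children), in the reverse order.
    lambda label, children: Tagged(children[0], label),
)

# After registration, only the array field is a leaf.
leaves_after = jax.tree_util.tree_leaves(t)

# tree_map transforms the leaf and carries the label through unchanged.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, t)
```

Keeping the label in aux_data rather than in children means it is never traced or differentiated; it only takes part in structure comparisons.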

Problem

Implement sum_packed_named_pair(x_first, x_second) that returns the sum of all elements across both arrays. This simplifies the registration concept to its mathematical core: once a container is flattened into its leaves, operations apply to each leaf independently.

Real code would register Pair and call jax.tree_util.tree_leaves(pair); here the inputs arrive pre-flattened.

Returns: a scalar equal to sum(x_first) + sum(x_second).

Examples (not from the test set):

  • sum_packed_named_pair(jnp.array([1., 2.]), jnp.array([3., 4.])) → 10.0
  • sum_packed_named_pair(jnp.array([0.]), jnp.array([0.])) → 0.0

Hints

jax pytree register-pytree-node
