register_pytree_node
Why this matters
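JAX's transforms (jit, grad, vmap) operate on pytrees: nested
containers of arrays. By default, JAX recognizes Python lists, tuples, dicts,
and None, and it also treats NamedTuple subclasses as pytrees automatically.
To use your own class (for example a plain class, a dataclass, or a custom node
type) with these transforms, you must register it as a pytree node. The examples
below register a NamedTuple because its short definition keeps the flatten and
unflatten functions easy to read.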
import jax
import jax.numpy as jnp
from typing import NamedTuple
class Pair(NamedTuple):
    first: jax.Array
    second: jax.Array
# Tell JAX how to flatten and unflatten a Pair:
jax.tree_util.register_pytree_node(
    Pair,
    flatten_func=lambda p: ([p.first, p.second], None),  # (children, aux_data)
    unflatten_func=lambda aux, children: Pair(*children),
)
# Now JAX treats Pair as a pytree:
p = Pair(jnp.array([1., 2.]), jnp.array([3., 4.]))
leaves = jax.tree_util.tree_leaves(p)
# → [Array([1., 2.], dtype=float32), Array([3., 4.], dtype=float32)]
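Once Pair is registered, jit, grad, and vmap can traverse it like any built-in
container. Libraries such as Flax and Optax rely on the same pytree mechanism so
that their parameter and state containers pass through these transforms.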
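To make the claim about jit, grad, and vmap concrete, here is a minimal sketch using a hypothetical plain class `Point` (not from the original). A plain class is chosen because, unlike a NamedTuple, it is not a pytree until it is registered. The function name `squared_norm` is also illustrative.

```python
import jax
import jax.numpy as jnp


class Point:
    """Hypothetical plain container with two array fields (illustration only)."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


# Register Point so that jit/grad/vmap can see its two array fields.
jax.tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),            # (children, aux_data)
    lambda aux, children: Point(*children),   # rebuild from children
)


@jax.jit
def squared_norm(p):
    # jit flattens the Point into its leaves before tracing.
    return jnp.sum(p.x ** 2) + jnp.sum(p.y ** 2)


p = Point(jnp.array([1.0, 2.0]), jnp.array([3.0]))
value = squared_norm(p)              # 1 + 4 + 9 = 14.0
grads = jax.grad(squared_norm)(p)    # a Point holding 2*x and 2*y

# vmap maps over axis 0 of every leaf: 4 rows, each a Point with shapes (2,) and (1,).
batched = jax.vmap(squared_norm)(Point(jnp.ones((4, 2)), jnp.ones((4, 1))))
```

Note that `jax.grad` returns a `Point`: gradients come back in the same structure as the input, which is what registration buys you.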
Worked mini-example
import jax
import jax.numpy as jnp
from typing import NamedTuple
class Pair(NamedTuple):
    first: jax.Array
    second: jax.Array
jax.tree_util.register_pytree_node(
    Pair,
    lambda p: ([p.first, p.second], None),
    lambda aux, ch: Pair(*ch),
)
p = Pair(jnp.array([1., 2., 3.]), jnp.array([4., 5., 6.]))
total = sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(p))
# total → 21.0  (1 + 2 + 3 + 4 + 5 + 6)
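The worked example only reads the leaves. A short sketch that reuses the same `Pair` definition shows the full round trip with `jax.tree_util.tree_flatten`, `tree_unflatten`, and `tree_map`; the variable names `doubled` and `shifted` are illustrative.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Pair(NamedTuple):
    first: jax.Array
    second: jax.Array


jax.tree_util.register_pytree_node(
    Pair,
    lambda p: ([p.first, p.second], None),
    lambda aux, ch: Pair(*ch),
)

p = Pair(jnp.array([1., 2., 3.]), jnp.array([4., 5., 6.]))

# tree_flatten returns the leaves plus a treedef describing the structure.
leaves, treedef = jax.tree_util.tree_flatten(p)

# tree_unflatten rebuilds a Pair from the treedef and a new list of leaves.
doubled = jax.tree_util.tree_unflatten(treedef, [2 * leaf for leaf in leaves])

# tree_map performs flatten -> apply per leaf -> unflatten in one call.
shifted = jax.tree_util.tree_map(lambda leaf: leaf + 1.0, p)
```

Both `doubled` and `shifted` come back as `Pair` instances, because the treedef remembers which node type to rebuild.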
Common pitfalls
- flatten_func must return (children, aux_data): children holds the child nodes (array leaves or nested pytrees); aux_data is non-array metadata and can be None.
- unflatten_func takes (aux_data, children): note that the order is reversed from what flatten_func returns.
- aux_data must be hashable and comparable with ==: it becomes part of the treedef, which jit includes in its compilation cache key.
- Register before first use: if tree_leaves or any transform sees a class that is neither built in nor registered, it treats each instance as a single opaque leaf.
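The pitfalls above can be checked directly. The sketch below assumes two hypothetical classes: `Labeled`, which carries a hashable string label as aux_data, and `Unregistered`, which is never registered. It shows the argument order of the unflatten function, that aux_data survives a tree_map, and that an unregistered instance is a single leaf.

```python
import jax
import jax.numpy as jnp


class Labeled:
    """Hypothetical container: one array leaf plus a static string label."""

    def __init__(self, value, label):
        self.value = value
        self.label = label


def flatten_labeled(obj):
    # children: the array leaves; aux_data: hashable, non-array metadata.
    return (obj.value,), obj.label


def unflatten_labeled(label, children):
    # Argument order: aux_data first, then children.
    (value,) = children
    return Labeled(value, label)


jax.tree_util.register_pytree_node(Labeled, flatten_labeled, unflatten_labeled)


class Unregistered:
    """Never registered, so JAX treats each instance as one opaque leaf."""

    def __init__(self, value):
        self.value = value


x = Labeled(jnp.array([1.0, 2.0]), "weights")
leaves = jax.tree_util.tree_leaves(x)                 # only the array; the label is aux_data
scaled = jax.tree_util.tree_map(lambda a: 10 * a, x)  # label is carried through unchanged

u = Unregistered(jnp.array([1.0, 2.0]))
opaque = jax.tree_util.tree_leaves(u)                 # [u] itself, not its inner array
```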
Problem
Implement sum_packed_named_pair(x_first, x_second), which returns the sum
of all elements across both arrays. This reduces the registration
concept to its mathematical core: once a container is flattened into its
leaves, an operation can be applied to each leaf independently and the results combined.
Real code would register Pair and call jax.tree_util.tree_leaves(pair);
here the inputs arrive pre-flattened.
Returns: a scalar equal to sum(x_first) + sum(x_second).
Examples (not from the test set):
- sum_packed_named_pair(jnp.array([1., 2.]), jnp.array([3., 4.])) → 10.0
- sum_packed_named_pair(jnp.array([0.]), jnp.array([0.])) → 0.0
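One possible sketch, not the site's reference solution: `sum_packed_named_pair` works on the pre-flattened arrays, while the hypothetical helper `sum_pair_pytree` follows the "real code" path described in the problem, registering Pair and reducing over `tree_leaves`. Both should agree on the same data.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


def sum_packed_named_pair(x_first, x_second):
    # Inputs arrive pre-flattened: each argument is one leaf of the pair.
    return jnp.sum(x_first) + jnp.sum(x_second)


class Pair(NamedTuple):
    first: jax.Array
    second: jax.Array


jax.tree_util.register_pytree_node(
    Pair,
    lambda p: ([p.first, p.second], None),
    lambda aux, ch: Pair(*ch),
)


def sum_pair_pytree(pair):
    # Hypothetical helper: flatten the container, then sum every leaf.
    return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(pair))


a = sum_packed_named_pair(jnp.array([1., 2.]), jnp.array([3., 4.]))
b = sum_pair_pytree(Pair(jnp.array([1., 2.]), jnp.array([3., 4.])))
```

The flattened version is what the problem asks for; the pytree version shows why it is equivalent once the container has been registered.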