tree_flatten / tree_unflatten Round-Trip
Why this matters
jax.tree_util.tree_flatten(pytree) and its inverse
jax.tree_util.tree_unflatten(treedef, leaves) are the low-level
primitives that every JAX transform uses internally. Understanding them
lets you:
- Serialize / deserialize pytrees (save params to disk, reload).
- Build custom optimizers that operate on flat leaf lists and then rebuild the original structure.
- Debug custom pytree registrations by verifying the round-trip.
- Inspect what JAX considers a "leaf" in your data structure.
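To illustrate the second use case, here is a minimal sketch of a plain gradient-descent step that works on flat leaf lists and then rebuilds the original structure. The gradient values and `learning_rate` are made up for the example; a real optimizer would obtain gradients from `jax.grad`.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
# Hypothetical gradients with the same structure as params.
grads = {"w": jnp.full((2, 3), 0.5), "b": jnp.ones(3)}
learning_rate = 0.1  # hypothetical value, chosen for illustration

param_leaves, treedef = jax.tree_util.tree_flatten(params)
grad_leaves, grad_treedef = jax.tree_util.tree_flatten(grads)

# Both trees must share a structure, otherwise zipping leaves pairs the wrong arrays.
assert treedef == grad_treedef

# Update each leaf independently, keeping the order tree_flatten produced.
new_leaves = [p - learning_rate * g for p, g in zip(param_leaves, grad_leaves)]

# Rebuild a dict with the original keys from the updated flat list.
new_params = jax.tree_util.tree_unflatten(treedef, new_leaves)
```

Comparing the two treedefs before zipping is a cheap guard: it catches a structural mismatch that would otherwise silently pair parameters with the wrong gradients.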
import jax, jax.numpy as jnp
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
leaves, treedef = jax.tree_util.tree_flatten(params)
# leaves -> [array([0,0,0]), array([[1,1,1],[1,1,1]])]  (dict keys are sorted, so "b" precedes "w")
# treedef -> PyTreeDef({'b': *, 'w': *})
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
# rebuilt == params (same structure, same leaves)
The treedef object encodes the container types, keys, and nesting order.
It is opaque but hashable; JAX caches compilations keyed partly by treedef.
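A short sketch of what this means in practice: two dicts with the same keys flatten to equal treedefs with equal hashes, regardless of leaf values or key insertion order. The variable names here are illustrative only.

```python
import jax
import jax.numpy as jnp

a = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
# Same keys, different values and a different insertion order.
b = {"b": jnp.full(3, 7.0), "w": jnp.zeros((2, 3))}

leaves_a, treedef_a = jax.tree_util.tree_flatten(a)
leaves_b, treedef_b = jax.tree_util.tree_flatten(b)

# The treedef describes structure only, so these compare equal and hash equally.
same_structure = treedef_a == treedef_b
same_hash = hash(treedef_a) == hash(treedef_b)

# Dict keys are flattened in sorted order: "b" first, then "w".
leaf_shapes = [leaf.shape for leaf in leaves_a]

# num_leaves is a public attribute of the treedef.
n_leaves = treedef_a.num_leaves
```

Because the treedef carries no leaf values, it can safely serve as a dictionary key or as part of a cache key.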
Worked mini-example
import jax, jax.numpy as jnp
arr = jnp.array([1., 2., 3., 4.])
# A single jax array is itself a (trivial) pytree with one leaf:
leaves, treedef = jax.tree_util.tree_flatten(arr)
# leaves -> [array([1., 2., 3., 4.])]
# treedef -> PyTreeDef(*)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
# rebuilt -> array([1., 2., 3., 4.])
total = jnp.sum(rebuilt)
# total -> 10.0
Common pitfalls
- Leaf order matters: tree_unflatten expects leaves in the same order that tree_flatten produced them. Shuffling leaves silently corrupts the structure.
- treedef is opaque: don't inspect its internals; use it only as the first argument to tree_unflatten.
- Different treedefs are incompatible: you cannot unflatten leaves from one treedef using a different treedef. JAX raises when the leaf count does not match; when the count happens to match, it silently builds the wrong structure.
- A single array is a valid pytree: its leaves list has exactly one element (the array itself).
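The first and third pitfalls can be reproduced in a few lines. This is a sketch for illustration: the `"only"` key is a made-up example, and the exception type is caught broadly because it may differ between JAX releases (to my knowledge it is a `ValueError`).

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
leaves, treedef = jax.tree_util.tree_flatten(params)  # leaf order is [b, w]

# Leaf order: reversing the leaves still unflattens without error,
# but each array now sits under the wrong key.
swapped = jax.tree_util.tree_unflatten(treedef, leaves[::-1])
swapped_shapes = {k: v.shape for k, v in swapped.items()}

# Incompatible treedefs: a treedef expecting one leaf rejects a two-leaf list.
_, other_treedef = jax.tree_util.tree_flatten({"only": jnp.zeros(1)})
try:
    jax.tree_util.tree_unflatten(other_treedef, leaves)
    raised = False
except Exception:  # broad on purpose; the exact type is an assumption
    raised = True
```

After the swap, `"b"` holds the (2, 3) array and `"w"` holds the length-3 array, with no warning from JAX, which is why a shape check after unflattening is worth the extra line.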
Problem
Implement flatten_then_unflatten_sum(packed_inputs) that:
- Calls jax.tree_util.tree_flatten(packed_inputs) to get (leaves, treedef).
- Calls jax.tree_util.tree_unflatten(treedef, leaves) to rebuild.
- Returns jnp.sum(rebuilt).
For a single JAX array (the degenerate pytree used in the tests),
leaves has one element (the array itself), and the rebuilt structure
is identical to the input. The sum is therefore jnp.sum(packed_inputs).
Returns: scalar.
Examples (not from the test set):
- flatten_then_unflatten_sum(jnp.array([2., 3.])) -> 5.0
- flatten_then_unflatten_sum(jnp.array([0., 0., 0.])) -> 0.0