medium primitives

tree_flatten / tree_unflatten Round-Trip

Why this matters

jax.tree_util.tree_flatten(pytree) and its inverse jax.tree_util.tree_unflatten(treedef, leaves) are the low-level primitives that every JAX transform uses internally. Understanding them lets you:

  • Serialize / deserialize pytrees (save params to disk, reload).
  • Build custom optimizers that operate on flat leaf lists and then rebuild the original structure.
  • Debug custom pytree registrations by verifying the round-trip.
  • Inspect what JAX considers a “leaf” in your data structure.

A basic round trip on a parameter dict:
import jax, jax.numpy as jnp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

leaves, treedef = jax.tree_util.tree_flatten(params)
# leaves   → [array([0., 0., 0.]), array([[1., 1., 1.], [1., 1., 1.]])]   (dict keys are sorted, so 'b' comes first)
# treedef  → PyTreeDef({'b': *, 'w': *})

rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
# rebuilt == params  (same structure, same leaves)
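The second use listed earlier, an optimizer that works on a flat leaf list, can be sketched as a plain SGD step. This is a minimal illustration only: `sgd_step_flat` and the learning rate are hypothetical, and in practice `jax.tree_util.tree_map` expresses the same per-leaf update more concisely.

```python
import jax
import jax.numpy as jnp

def sgd_step_flat(params, grads, lr=0.1):
    # Hypothetical helper: flatten both pytrees so their leaves line up one-to-one.
    p_leaves, p_treedef = jax.tree_util.tree_flatten(params)
    g_leaves, g_treedef = jax.tree_util.tree_flatten(grads)
    assert p_treedef == g_treedef, "params and grads must share a structure"
    # Update each leaf independently on the flat list.
    new_leaves = [p - lr * g for p, g in zip(p_leaves, g_leaves)]
    # Rebuild the original dict structure from the updated leaves.
    return jax.tree_util.tree_unflatten(p_treedef, new_leaves)

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
grads = {"w": jnp.full((2, 3), 2.0), "b": jnp.ones(3)}
new_params = sgd_step_flat(params, grads, lr=0.1)
# new_params["w"] is all 0.8 and new_params["b"] is all -0.1
```

Because the update happens on a plain Python list, the same function works for any pytree whose gradients share the parameters' structure.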

The treedef object encodes the container types, keys, and nesting order. It is opaque but hashable and comparable with ==; jax.jit uses it as part of the key for its compilation cache.
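A quick check of what the treedef does and does not record, relying only on `PyTreeDef` equality, hashing, and its `num_leaves` property. Two dicts with the same keys share a treedef even when their leaf shapes differ, which is why jit's cache also depends on leaf shapes and dtypes rather than on the treedef alone.

```python
import jax
import jax.numpy as jnp

a = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
b = {"w": jnp.zeros((5,)), "b": jnp.ones(())}   # same keys, different leaf shapes
c = {"w": jnp.ones((2, 3))}                      # different keys

_, td_a = jax.tree_util.tree_flatten(a)
_, td_b = jax.tree_util.tree_flatten(b)
_, td_c = jax.tree_util.tree_flatten(c)

print(td_a == td_b)     # True: containers and keys match; shapes are not recorded
print(td_a == td_c)     # False: the key sets differ
print(td_a.num_leaves)  # 2

# Equal treedefs hash equally, so a treedef works as a dict key.
cache = {td_a: "entry for {'b', 'w'} dicts"}
print(cache[td_b])
```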

Worked mini-example

import jax, jax.numpy as jnp

arr = jnp.array([1., 2., 3., 4.])

# A single jax array is itself a (trivial) pytree with one leaf:
leaves, treedef = jax.tree_util.tree_flatten(arr)
# leaves   → [array([1., 2., 3., 4.])]
# treedef  → PyTreeDef(*)

rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
# rebuilt  → array([1., 2., 3., 4.])

total = jnp.sum(rebuilt)
# total → 10.0
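The same round trip doubles as a test for a custom pytree registration (the third use listed at the top). The `Linear` class below is a hypothetical example registered with `jax.tree_util.register_pytree_node_class`; if its `tree_flatten` and `tree_unflatten` methods disagree, the assertions at the end fail.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Linear:
    # Hypothetical container: both array fields become pytree leaves.
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def tree_flatten(self):
        children = (self.w, self.b)  # the leaves (or sub-pytrees)
        aux_data = None              # no static metadata needed here
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

layer = Linear(jnp.ones((2, 3)), jnp.zeros(3))
leaves, treedef = jax.tree_util.tree_flatten(layer)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# Round-trip checks: same type, same leaf count, same values.
assert isinstance(rebuilt, Linear)
assert treedef.num_leaves == 2
assert jnp.array_equal(rebuilt.w, layer.w) and jnp.array_equal(rebuilt.b, layer.b)
```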

Common pitfalls

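  • Leaf order matters: tree_unflatten expects leaves in the same order that tree_flatten produced them. Shuffled leaves are not detected; values silently end up in the wrong positions.
  • treedef is opaque: don’t depend on its internal representation. Treat it as the structure token you pass back to tree_unflatten (comparing treedefs with == is fine).
  • Mismatched treedefs are only partly caught: tree_unflatten raises when the number of leaves differs from what the treedef expects, but leaves from one structure can be unflattened into a different treedef with the same leaf count without any error, producing the wrong container.
  • A single array is a valid pytree: its leaf list has exactly one element (the array itself).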
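The pitfalls can be checked directly. This sketch shows which mistakes raise and which pass silently; it does not rely on the exact exception type for a leaf-count mismatch, so it catches `Exception`.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
leaves, treedef = jax.tree_util.tree_flatten(params)  # leaves: [b, w]

# 1. Wrong leaf count: tree_unflatten raises.
try:
    jax.tree_util.tree_unflatten(treedef, leaves[:1])
    raised = False
except Exception:
    raised = True

# 2. Same leaf count, different treedef: no error, just a different container.
_, tuple_treedef = jax.tree_util.tree_flatten((0.0, 0.0))
as_tuple = jax.tree_util.tree_unflatten(tuple_treedef, leaves)
# as_tuple is a 2-tuple (b_leaf, w_leaf); the dict keys are gone.

# 3. Shuffled leaves: no error, but values land under the wrong keys.
swapped = jax.tree_util.tree_unflatten(treedef, leaves[::-1])
# swapped["b"] now has shape (2, 3) and swapped["w"] has shape (3,).
```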

Problem

Implement flatten_then_unflatten_sum(packed_inputs) that:

  1. Calls jax.tree_util.tree_flatten(packed_inputs) to get (leaves, treedef).
  2. Calls jax.tree_util.tree_unflatten(treedef, leaves) to rebuild.
  3. Returns jnp.sum(rebuilt).

For a single JAX array (the degenerate pytree used in the tests), leaves has one element (the array itself), and the rebuilt structure is identical to the input. The sum is therefore jnp.sum(packed_inputs).

Returns: a scalar (0-d JAX array).

Examples (not from the test set):

  • flatten_then_unflatten_sum(jnp.array([2., 3.])) → 5.0
  • flatten_then_unflatten_sum(jnp.array([0., 0., 0.])) → 0.0
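A minimal sketch that follows the three steps literally. It assumes, as the tests do, that packed_inputs is a single array; for a dict or tuple, jnp.sum(rebuilt) would not work, and you would sum per leaf instead (for example sum(jnp.sum(x) for x in leaves)).

```python
import jax
import jax.numpy as jnp

def flatten_then_unflatten_sum(packed_inputs):
    # Step 1: split the pytree into its leaves and its structure.
    leaves, treedef = jax.tree_util.tree_flatten(packed_inputs)
    # Step 2: rebuild the original structure from the same leaves.
    rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
    # Step 3: for a single array, rebuilt is that array; sum its elements.
    return jnp.sum(rebuilt)

print(flatten_then_unflatten_sum(jnp.array([2., 3.])))      # 5.0
print(flatten_then_unflatten_sum(jnp.array([0., 0., 0.])))  # 0.0
```

Flattening and unflattening are pure Python bookkeeping, so the function also works unchanged under jax.jit.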

Hints

jax pytree tree-flatten
