medium primitives

Pytree Map

Why this matters

jax.tree_util.tree_map(fn, pytree) applies fn to every leaf of a pytree, returning a new pytree with the same structure and transformed leaves. This is the primitive for transforming neural-net parameters: scaling gradients, applying weight decay, mixing two parameter sets, casting precision, and more.

import jax.numpy as jnp
from jax import tree_util as jtu

# ... further layers elided; any nesting depth works the same way.
params = {"layer1": {"w": jnp.zeros((3, 4)), "b": jnp.zeros((4,))}}

# Multiply every leaf by 0.99 (e.g., the decay term of an EMA update):
damped = jtu.tree_map(lambda l: l * 0.99, params)
# `damped` has the SAME nested-dict structure as `params`.

# Combine two pytrees element-wise via tree_map of TWO args
# (params1 and params2 must have identical structure):
params1, params2 = params, damped
sum_params = jtu.tree_map(lambda a, b: a + b, params1, params2)
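Multiplying by 0.99 covers only the decay term of an exponential moving average. A full EMA step needs both the running average and the fresh parameters, which is exactly the two-tree form of tree_map. Below is a minimal sketch; the helper name `ema_update` and the default decay of 0.99 are illustrative assumptions, not a library API.

```python
import jax.numpy as jnp
from jax import tree_util as jtu

def ema_update(ema_params, new_params, decay=0.99):
    # Hypothetical helper: ema <- decay * ema + (1 - decay) * new, leaf by leaf.
    # tree_map pairs leaves by position, so both trees must share one structure.
    return jtu.tree_map(lambda e, p: decay * e + (1.0 - decay) * p,
                        ema_params, new_params)

ema = {"layer1": {"w": jnp.zeros((3, 4)), "b": jnp.zeros((4,))}}
new = {"layer1": {"w": jnp.ones((3, 4)), "b": jnp.ones((4,))}}
ema = ema_update(ema, new)
# Every leaf of `ema` is now about 0.01 (0.99 * 0 + 0.01 * 1), same nesting.
```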

The flagship win is structural agnosticism: tree_map doesn’t care whether your params are nested 2 levels deep or 20; it just walks the tree down to the leaves. The same pytree machinery is what lets JAX’s transformations (jax.grad, jax.jit, jax.vmap) compose across arbitrary parameter layouts.
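To make the agnosticism concrete, the sketch below applies one unchanged function to two made-up trees of very different depth, mixing dicts, lists, and tuples. The function name `halve` and both trees are hypothetical; the check uses `jtu.tree_structure` to confirm the container layout survives the map.

```python
import jax.numpy as jnp
from jax import tree_util as jtu

def halve(tree):
    # One definition, any layout: tree_map recurses through dicts, lists
    # and tuples and applies the lambda only at the leaves.
    return jtu.tree_map(lambda leaf: leaf / 2, tree)

# Two made-up trees of very different shape and depth.
shallow = {"w": jnp.array([2.0, 4.0])}
deep = {"encoder": [{"attn": (jnp.array([8.0]), jnp.array([6.0]))},
                    {"mlp": {"w": jnp.array([[4.0]])}}]}

for tree in (shallow, deep):
    out = halve(tree)
    # Container layout is unchanged; only the leaf values differ.
    assert jtu.tree_structure(out) == jtu.tree_structure(tree)
```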

Worked mini-example

from jax import tree_util as jtu
import jax.numpy as jnp

pytree = {
    "w": jnp.array([[1., 2.], [3., 4.]]),
    "b": jnp.array([5., 6.]),
}

# Multiply every leaf by 2.0:
scaled = jtu.tree_map(lambda l: l * 2.0, pytree)
# scaled = {"w": [[2, 4], [6, 8]], "b": [10, 12]}

# No nested loops, no flatten/unflatten boilerplate.
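For comparison, this is roughly the flatten/unflatten boilerplate that the one-liner replaces: flatten the tree into a list of leaves plus a `treedef` describing its structure, map over the list, then rebuild. It is a conceptual sketch of the behavior, not a claim about how tree_map is implemented internally.

```python
import jax.numpy as jnp
from jax import tree_util as jtu

pytree = {
    "w": jnp.array([[1., 2.], [3., 4.]]),
    "b": jnp.array([5., 6.]),
}

# 1. Flatten: leaves in a deterministic order (dict keys sorted) + structure.
leaves, treedef = jtu.tree_flatten(pytree)

# 2. Map over the flat list of leaves.
new_leaves = [leaf * 2.0 for leaf in leaves]

# 3. Unflatten: rebuild the original container layout around the new leaves.
manual = jtu.tree_unflatten(treedef, new_leaves)

# The one-liner gives the same tree.
scaled = jtu.tree_map(lambda l: l * 2.0, pytree)
```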

Common pitfalls

  • Single-arg map: tree_map(fn, t) calls fn(leaf) per leaf.
  • Multi-arg map: tree_map(fn, t1, t2, ...) requires identical structure across all trees; fn(leaf1, leaf2, ...) per matched leaf.
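  • fn receives leaves, never containers: dicts, lists, tuples, and None are pytree nodes that tree_map walks through, so fn only sees what they hold. In a parameter tree the leaves are usually JAX arrays, but any non-container value (a Python float, say) is also a leaf, so fn must handle whatever leaf types your tree actually contains.
  • Closures over Python state: under jax.jit, tree_map (and therefore fn) runs at trace time, so any Python value fn closes over is baked into the trace as a constant. Later changes to that value are not picked up until the function is retraced; pass it as an argument instead.
  • None placeholders: None is treated as an empty subtree rather than a leaf, so fn is never called on it and it stays None in the output. This is useful for masking gradients to “frozen” params; pass is_leaf=lambda x: x is None if you want fn to see the Nones.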
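The pitfalls above can be reproduced in a few lines. This is a minimal sketch with made-up gradient trees; `scale_tree` and `scale_tree_arg` are hypothetical names. The exact exception type for a structure mismatch is treated as an assumption, so the sketch catches both `ValueError` and `TypeError`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util as jtu

grads = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}

# Multi-arg map: the trees must match. Different dict keys raise an error.
other = {"w": jnp.ones(2), "bias": jnp.ones(1)}
try:
    jtu.tree_map(lambda a, b: a + b, grads, other)
    mismatch_raised = False
except (ValueError, TypeError):
    mismatch_raised = True

# None is an empty subtree, not a leaf: fn is never called on it,
# so a "frozen" entry passes through untouched.
lr = 0.1
masked = {"w": jnp.array([1.0, 2.0]), "frozen": None}
updates = jtu.tree_map(lambda g: -lr * g, masked)  # updates["frozen"] is None

# is_leaf opts None back in as a leaf, so fn does see it.
filled = jtu.tree_map(lambda g: jnp.zeros(1) if g is None else g,
                      masked, is_leaf=lambda x: x is None)

# Closure pitfall: under jit, the global `scale` is read once at trace time.
scale = 2.0

@jax.jit
def scale_tree(tree):
    return jtu.tree_map(lambda l: l * scale, tree)

first = scale_tree(grads)  # traces with scale = 2.0
scale = 10.0
stale = scale_tree(grads)  # same shapes -> cached trace, still uses 2.0

# Fix: pass the scalar as an argument so it becomes a traced input.
@jax.jit
def scale_tree_arg(tree, s):
    return jtu.tree_map(lambda l: l * s, tree)
```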

Problem

Implement scale_all_leaves(leaves_concat, leaf_sizes, scale). The function receives a flat concatenated array in place of a real pytree (the test contract can’t deliver arbitrary nested dicts). The math is simply leaves_concat * scale — a scalar broadcast over the whole array.

The lesson is about tree_map, not the math. In real JAX code you would write:

from jax import tree_util as jtu
scaled_pytree = jtu.tree_map(lambda l: l * scale, pytree)

leaf_sizes tells you how many elements each original leaf contributed to leaves_concat — it is unused in the actual computation (broadcasting makes per-leaf splitting pointless), but it’s there to keep the signature honest about the pytree context.

Two illustrative examples (not from the test set):

  • scale_all_leaves(jnp.array([1., 2., 3.]), jnp.array([3.]), 0.5) → [0.5, 1.0, 1.5]
  • scale_all_leaves(jnp.array([4., -2., 0.]), jnp.array([1., 2.]), 3.0) → [12., -6., 0.]
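To connect the flat contract back to tree_map, the sketch below uses leaf_sizes to slice leaves_concat back into per-leaf arrays, maps over that list (a list is itself a pytree), and re-concatenates. The helper `scale_via_tree_map` is hypothetical, and it assumes leaf_sizes sums to len(leaves_concat). It yields the same values as leaves_concat * scale, which is why the per-leaf split is unnecessary.

```python
import numpy as np
import jax.numpy as jnp
from jax import tree_util as jtu

def scale_via_tree_map(leaves_concat, leaf_sizes, scale):
    # Hypothetical helper. leaf_sizes may arrive as floats, so cast to int.
    sizes = [int(n) for n in np.asarray(leaf_sizes)]
    # Slice the flat buffer back into one array per original leaf.
    leaves, start = [], 0
    for n in sizes:
        leaves.append(leaves_concat[start:start + n])
        start += n
    # Map over the list of leaves, then flatten back into one array.
    scaled = jtu.tree_map(lambda l: l * scale, leaves)
    return jnp.concatenate(scaled)
```

Comparing its output with a plain leaves_concat * scale on the two examples above gives identical arrays.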

Hints

jax pytree tree-map
