
Pytree Aux Data

Why this matters

When registering a custom pytree node, the flatten function returns (children, aux_data). The aux_data carries non-array metadata that is preserved through jit, grad, and vmap, but it is stored in the treedef rather than exposed as a leaf, so it is never traced or differentiated.

Common uses of aux_data:

  • Layer names / configuration: Flax stores module hyperparameters (e.g. num_heads, dtype) as aux_data so recompilation is triggered only when the structure or configuration changes, not when the weight values change.
  • Axis specifications: sharding specs that travel with the container.
  • Static flags such as use_bias: bool and activation: str.

import jax, jax.numpy as jnp
from typing import NamedTuple

class LayerParams(NamedTuple):
    weight: jax.Array
    bias:   jax.Array
    name:   str   # non-array, so it must live in aux_data

jax.tree_util.register_pytree_node(
    LayerParams,
    # flatten: separate arrays (children) from name (aux)
    flatten_func=lambda lp: ([lp.weight, lp.bias], lp.name),
    # unflatten: reconstruct from aux and children
    unflatten_func=lambda name, children: LayerParams(*children, name),
)

After registration, jax.tree_util.tree_leaves(lp) returns only [weight, bias]; name is stored in the treedef, not exposed as a leaf.
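
To make the "stored in the treedef" point concrete, here is a minimal round-trip sketch. It uses a hypothetical plain class named Layer (not the LayerParams above) so that the block runs on its own; the class and variable names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


class Layer:
    """Hypothetical container: two array fields plus a string name."""

    def __init__(self, weight, bias, name):
        self.weight = weight
        self.bias = bias
        self.name = name


jax.tree_util.register_pytree_node(
    Layer,
    # flatten returns (children, aux_data)
    lambda layer: ((layer.weight, layer.bias), layer.name),
    # unflatten receives (aux_data, children)
    lambda name, children: Layer(*children, name),
)

layer = Layer(jnp.ones(3), jnp.zeros(1), "linear_1")

# leaves holds only the two arrays; the name travels inside treedef.
leaves, treedef = jax.tree_util.tree_flatten(layer)

# Rebuilding from (treedef, leaves) restores the name from the treedef.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

Here rebuilt.name is recovered even though the string never appeared among the leaves.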

Worked mini-example

import jax, jax.numpy as jnp

lp = LayerParams(
    weight=jnp.array([1., 2., 3.]),
    bias=jnp.array([0.1]),
    name="linear_1",
)

leaves = jax.tree_util.tree_leaves(lp)
# → [array([1., 2., 3.]), array([0.1])]
# name = "linear_1" is in aux_data, not a leaf

scaled = jax.tree_util.tree_map(lambda l: l * 2.0, lp)
# → LayerParams(weight=[2,4,6], bias=[0.2], name="linear_1")
# aux_data (name) is preserved unchanged
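
The same preservation holds under transformations, not only tree_map. The sketch below is an assumption-laden illustration: it redefines a hypothetical plain class Layer and a toy loss function so that it is self-contained, then takes a jitted gradient with respect to the whole container.

```python
import jax
import jax.numpy as jnp


class Layer:
    """Hypothetical container: weight and bias are children, name is aux_data."""

    def __init__(self, weight, bias, name):
        self.weight = weight
        self.bias = bias
        self.name = name


jax.tree_util.register_pytree_node(
    Layer,
    lambda layer: ((layer.weight, layer.bias), layer.name),
    lambda name, children: Layer(*children, name),
)


def loss(layer, x):
    # Toy scalar loss; bias (shape (1,)) broadcasts against weight * x.
    return jnp.sum(layer.weight * x + layer.bias)


layer = Layer(jnp.array([1.0, 2.0, 3.0]), jnp.array([0.1]), "linear_1")
x = jnp.array([1.0, 1.0, 1.0])

# The gradient has the same pytree structure as `layer`:
# one gradient array per child, and the name carried over unchanged.
grads = jax.jit(jax.grad(loss))(layer, x)
```

The result is itself a Layer: grads.weight and grads.bias hold gradients, while grads.name is simply copied from the treedef.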

Common pitfalls

  • aux_data must be hashable and comparable with ==. JAX hashes and compares the treedef, which contains aux_data, when it looks up cached compilations. Lists, dicts, arrays, and other unhashable objects typically raise a TypeError once the treedef is hashed, for example on a jit call.
  • Changes to aux_data produce a new treedef. If name changes from "linear_1" to "linear_2", JAX will retrace and recompile the jitted function. Keep aux_data stable across calls.
  • Do not put arrays in aux_data. They won't be differentiated and won't be updated by optimizers. Move any trainable array to children.
  • None is a valid aux_data. Use it when your container has no metadata.
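
The recompilation pitfall can be observed directly by recording each trace. This is a rough sketch under stated assumptions: Layer, double, and trace_log are hypothetical names, and the Python-side append is used only because it runs at trace time, not on cached calls.

```python
import jax
import jax.numpy as jnp


class Layer:
    """Hypothetical container: weight is a child, name is aux_data."""

    def __init__(self, weight, name):
        self.weight = weight
        self.name = name


jax.tree_util.register_pytree_node(
    Layer,
    lambda layer: ((layer.weight,), layer.name),
    lambda name, children: Layer(*children, name),
)

trace_log = []  # one entry per trace of `double`


@jax.jit
def double(layer):
    # This append executes only while JAX traces the function.
    trace_log.append(layer.name)
    return layer.weight * 2.0


double(Layer(jnp.ones(3), "linear_1"))   # first call: traces
double(Layer(jnp.zeros(3), "linear_1"))  # same aux_data, new values: cached
traces_with_same_name = len(trace_log)

double(Layer(jnp.ones(3), "linear_2"))   # different aux_data: traces again
traces_after_rename = len(trace_log)
```

Changing only the array values reuses the cached computation, while changing the name forces a second trace. If the metadata is naturally a list or dict, converting it to a tuple before returning it from the flatten function is one common way to keep it hashable.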

Problem

Implement aux_data_demo(x, scale) that returns x * scale. This demonstrates the mathematical primitive that sits at the heart of aux_data-aware transforms: once the container is flattened and aux_data is set aside, per-leaf operations (like scaling) apply uniformly.

In real JAX code, scale would live in aux_data (it's a configuration value, not a trainable parameter), and x would be a child leaf.

Returns: 1-D array equal to x * scale.

Examples (not from the test set):

  • aux_data_demo(jnp.array([1., 2.]), 3.0) → [3., 6.]
  • aux_data_demo(jnp.array([5., -5.]), 0.0) → [0., 0.]
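
For context, the remark that scale would live in aux_data in real JAX code might look like the sketch below. ScaledVector and apply_scale are hypothetical names introduced only for illustration; this is not the graded solution, just one way the idea could be arranged.

```python
import jax
import jax.numpy as jnp


class ScaledVector:
    """Hypothetical container: x is a child leaf, scale is static aux_data."""

    def __init__(self, x, scale):
        self.x = x
        self.scale = scale


jax.tree_util.register_pytree_node(
    ScaledVector,
    lambda sv: ((sv.x,), sv.scale),
    lambda scale, children: ScaledVector(*children, scale),
)


@jax.jit
def apply_scale(sv):
    # sv.x is a traced array; sv.scale is a plain Python float
    # that is fixed for this particular trace.
    return sv.x * sv.scale


out = apply_scale(ScaledVector(jnp.array([1.0, 2.0]), 3.0))
```

Because scale is part of the treedef here, calling apply_scale with a different scale value would trigger a retrace, which is the trade-off described under the common pitfalls.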

Hints

jax pytree aux-data
