Pytree Aux Data
Why this matters
When registering a custom pytree node, flatten_func returns
(children, aux_data). The aux_data carries non-array metadata
that is preserved through jit, grad, and vmap, but is not
treated as a differentiable leaf.
Common uses of aux_data:
- Layer names / configuration: Flax stores module hyperparameters (e.g. num_heads, dtype) as aux_data so recompilation is triggered only when the structure changes, not the weights.
- Axis specifications: sharding specs that travel with the container.
- Static flags: use_bias: bool, activation: str.
import jax, jax.numpy as jnp
from typing import NamedTuple
class LayerParams(NamedTuple):
    weight: jax.Array
    bias: jax.Array
    name: str  # non-array, must live in aux_data

jax.tree_util.register_pytree_node(
    LayerParams,
    # flatten: separate arrays (children) from name (aux)
    flatten_func=lambda lp: ([lp.weight, lp.bias], lp.name),
    # unflatten: reconstruct from aux and children
    unflatten_func=lambda name, children: LayerParams(*children, name),
)
After registration, jax.tree_util.tree_leaves(lp) returns only
[weight, bias]; name is hidden in the treedef, not exposed as a leaf.
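To make the split between leaves and treedef concrete, here is a minimal sketch of a flatten/unflatten round trip. It uses a hypothetical plain class, Layer, rather than the LayerParams NamedTuple above, purely so the block is self-contained; the names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


class Layer:
    """Hypothetical container: two array fields plus a static name."""

    def __init__(self, weight, bias, name):
        self.weight = weight
        self.bias = bias
        self.name = name


jax.tree_util.register_pytree_node(
    Layer,
    lambda layer: ((layer.weight, layer.bias), layer.name),  # -> (children, aux_data)
    lambda name, children: Layer(*children, name),           # (aux_data, children) -> node
)

layer = Layer(jnp.ones(3), jnp.zeros(1), "linear_1")

# Flattening yields the array leaves plus a treedef that carries the name.
leaves, treedef = jax.tree_util.tree_flatten(layer)

# Unflattening with the same treedef restores the name from aux_data.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# A container that differs only in aux_data has a different treedef.
renamed = Layer(layer.weight, layer.bias, "linear_2")
same_structure = jax.tree_util.tree_structure(renamed) == treedef
```

Here leaves holds only the two arrays, rebuilt.name is the original string, and same_structure is False because the treedef comparison includes aux_data.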
Worked mini-example
import jax, jax.numpy as jnp
lp = LayerParams(
    weight=jnp.array([1., 2., 3.]),
    bias=jnp.array([0.1]),
    name="linear_1",
)
leaves = jax.tree_util.tree_leaves(lp)
# → [Array([1., 2., 3.], dtype=float32), Array([0.1], dtype=float32)]
# name = "linear_1" is in aux_data, not a leaf
scaled = jax.tree_util.tree_map(lambda l: l * 2.0, lp)
# → LayerParams(weight=[2,4,6], bias=[0.2], name="linear_1")
# aux_data (name) is preserved unchanged
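The same preservation holds under transformations such as jit and grad. The following is a rough sketch, again with a hypothetical plain class Layer and an illustrative loss function (neither comes from the original text), showing that gradients are computed only for the array children while the name passes through untouched.

```python
import jax
import jax.numpy as jnp


class Layer:
    """Hypothetical container: two array fields plus a static name."""

    def __init__(self, weight, bias, name):
        self.weight = weight
        self.bias = bias
        self.name = name


jax.tree_util.register_pytree_node(
    Layer,
    lambda layer: ((layer.weight, layer.bias), layer.name),
    lambda name, children: Layer(*children, name),
)


def loss(layer, x):
    # weight and bias are traced; layer.name stays an ordinary Python str.
    return jnp.sum(layer.weight * x + layer.bias)


layer = Layer(jnp.array([1., 2., 3.]), jnp.array([0.1]), "linear_1")
x = jnp.array([1., 1., 2.])

# grad returns a Layer whose children are gradients and whose aux_data is unchanged.
grads = jax.jit(jax.grad(loss))(layer, x)
```

grads.weight equals x, grads.bias is [3.] (the one-element bias is broadcast over three terms of the sum), and grads.name is still "linear_1".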
Common pitfalls
- aux_data MUST be hashable: JAX hashes it to form the treedef cache key. Arrays, lists containing arrays, and other non-hashable objects will raise a TypeError once the treedef is hashed, for example when the container is passed to a jitted function.
- Changes to aux_data trigger a new treedef: if name changes from "linear_1" to "linear_2", JAX will recompile the jitted function. Keep aux_data stable across calls.
- Do not put arrays in aux_data: they won't be differentiated and won't be updated by optimizers. Move any trainable array to children.
- None is a valid aux_data: use it when your container has no metadata.
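The recompilation pitfall can be observed directly by recording when a jitted function is traced. This is a minimal sketch under stated assumptions: Layer, forward, and trace_log are hypothetical names introduced for illustration, and the Python-side append runs only during tracing, not on cached calls.

```python
import jax
import jax.numpy as jnp


class Layer:
    """Hypothetical container: two array fields plus a static name."""

    def __init__(self, weight, bias, name):
        self.weight = weight
        self.bias = bias
        self.name = name


jax.tree_util.register_pytree_node(
    Layer,
    lambda layer: ((layer.weight, layer.bias), layer.name),
    lambda name, children: Layer(*children, name),
)

trace_log = []  # filled only while JAX traces forward


@jax.jit
def forward(layer, x):
    trace_log.append(layer.name)  # Python side effect: executes at trace time only
    return layer.weight * x + layer.bias


x = jnp.ones(3)
w, b = jnp.ones(3), jnp.zeros(3)

forward(Layer(w, b, "linear_1"), x)      # first call: traced and compiled
forward(Layer(w * 2, b, "linear_1"), x)  # new values, same aux_data: cache hit
forward(Layer(w, b, "linear_2"), x)      # new aux_data: traced again
```

After these three calls, trace_log contains two entries, one per distinct name, which is why frequently changing values belong in children rather than in aux_data.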
Problem
Implement aux_data_demo(x, scale) that returns x * scale. This
demonstrates the mathematical primitive that sits at the heart of
aux_data-aware transforms: once the container is flattened and aux_data
is set aside, per-leaf operations (like scaling) apply uniformly.
In real JAX code, scale would live in aux_data (it's a configuration
value, not a trainable parameter), and x would be a child leaf.
Returns: 1-D array, x * scale.
Examples (not from the test set):
- aux_data_demo(jnp.array([1., 2.]), 3.0) → [3., 6.]
- aux_data_demo(jnp.array([5., -5.]), 0.0) → [0., 0.]
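For context, the "real JAX code" arrangement described above, with scale in aux_data and x as a child leaf, might look roughly like the sketch below. Scaled and apply_scale are hypothetical names chosen for illustration and are not part of the problem's required interface.

```python
import jax
import jax.numpy as jnp


class Scaled:
    """Hypothetical container: x is an array leaf, scale is static configuration."""

    def __init__(self, x, scale):
        self.x = x
        self.scale = scale


jax.tree_util.register_pytree_node(
    Scaled,
    lambda s: ((s.x,), s.scale),                          # x is a child, scale is aux_data
    lambda scale, children: Scaled(children[0], scale),
)


@jax.jit
def apply_scale(s):
    # s.x is traced; s.scale is a plain Python float fixed at trace time.
    return s.x * s.scale


out = apply_scale(Scaled(jnp.array([1., 2.]), 3.0))
```

Because scale is part of the treedef here, calling apply_scale with a different scale would trigger a retrace, which is acceptable for a rarely changing configuration value.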