flax.struct.dataclass — Pytree-Friendly State
Why this matters
Stateful training loops have to bundle a lot of state: parameters,
optimizer state, step counter, EMA copies, RNG keys, dropout masks.
A Python dict works at first but breaks down fast: it is not hashable,
so jax.jit cannot take it as a static argument; it gives
jax.tree_util.tree_map no way to tell leaves from metadata; and it is
mutable, while Optax and flax.training expect immutable updates.
flax.struct.dataclass is the answer: a JAX-aware @dataclass
decorator that:
- Registers your class with jax.tree_util so it’s a pytree.
- Makes the instance immutable (no field assignment).
- Adds a replace(**kwargs) method for functional updates.
- Lets you mark fields as pytree_node=False (treated as static metadata, not as leaves).
TrainState itself is built on top of flax.struct.dataclass.
Knowing how it works under the hood lets you build your own
custom state classes for novel training setups (PPO with separate
actor/critic state, EMA for diffusion, etc).
Defining one
from flax import struct
from typing import Any
import jax.numpy as jnp

@struct.dataclass
class MyState:
    weights: jnp.ndarray
    step: int = 0

Now MyState(weights=w, step=0):
- Is immutable: state.step = 5 raises dataclasses.FrozenInstanceError.
- Has state.replace(step=5) for functional updates; it returns a new MyState and leaves state untouched.
- Is a pytree: jax.tree_util.tree_map(lambda x: x * 2, state) doubles every leaf (the weights array AND step).
Pytree leaves vs. static fields
By default, every field is a pytree node, so its contents become
leaves. That’s usually wrong for things like learning rate schedules
or boolean flags: they’re metadata, not arrays. Use pytree_node=False:

@struct.dataclass
class TrainCfg:
    params: Any
    step: int = 0
    use_ema: bool = struct.field(pytree_node=False, default=False)

Now tree_map skips use_ema (treated as static); jit recompiles
if use_ema changes; the field doesn’t appear in jax.tree_util.tree_leaves.
Conceptually:
- leaves = arrays the JIT sees as runtime values.
- static = Python values baked into the trace; changing them retraces the function.
The replace pattern
Because the dataclass is frozen, the canonical update is:
state = state.replace(step=state.step + 1)
This is not mutation — it constructs a new instance with all
other fields copied verbatim and just step overridden. The old
state is unchanged; if anyone else still holds a reference to
it, they see the old value.
This is the same functional-update pattern as Elixir/Erlang structs,
Clojure records, and Rust’s struct update syntax (..).
Why not just dict?
Dicts are pytrees too — jax.tree_util walks them fine. The
advantages of struct.dataclass:
- Type checking: state.steps (typo) raises an AttributeError and can be flagged by a static type checker before the code runs, whereas state["steps"] raises a KeyError only at runtime.
- Hashability: a frozen dataclass with hashable fields hashes to a stable value, so jit can accept it as a static argument.
- Documented schema: anyone reading the class sees the fields and their types in one place.
- Methods: you can add helper methods (apply_gradients, etc).
flax.training.train_state.TrainState uses all four advantages.
Common pitfalls
- pytree_node=False and JIT recompilation: marking a field static means every distinct value causes recompilation. Don’t mark a step counter as pytree_node=False unless you really want to recompile every step.
- Default mutable values: weights: list = [] is a classic Python footgun, and dataclasses rejects it with a ValueError at class-definition time. Use struct.field(default_factory=list) if you absolutely must, but for JAX you almost always want an array, not a list.
- replace accepts only existing fields: state.replace(typo=5) raises TypeError. This is a feature.
- Pickle pitfalls: flax.struct.dataclass instances pickle, but the class definition needs to be importable on the unpickling end.
Problem
Define a flax.struct.dataclass named MyState with two fields:
- weights: an array.
- step: an integer (default 0).
Then:
- Construct instance = MyState(weights=jnp.asarray(params_a), step=0). (The params_b argument is unused; it’s there for signature consistency with related exercises.)
- Update via instance = instance.replace(step=instance.step + 1).
- Return [float(instance.weights.sum()), float(instance.step)] as a 1-D (2,) array.
After one replace(step=...) call, instance.step == 1.
Inputs:
- seed: float, unused (passed for consistency).
- params_a: array, used as weights.
- params_b: array, unused.
Output: 1-D (2,) — [weights.sum(), step].