Multiple Mutable Collections
Why this matters
Flax’s variable system isn’t limited to params. Variables can
live in any collection name you invent. Each collection has its
own lifecycle and its own mutability. RNG streams (such as params
and dropout) are named separately and supplied at init or apply. The
standard collections you’ve already seen:
- params: trainable parameters; updated by the optimizer.
- batch_stats: running mean/var for BatchNorm (pos 17).
- cache: KV-cache for attention at decode time (pos 38).
But you can declare your own collection for any non-parameter state you want the model to keep around: counters, EMA shadow weights, debug histograms, replay buffers, gating temperatures, learned-but-not-gradient state. The mechanism is uniform:
self.variable("<collection>", "<var_name>", init_fn)
During apply, you tell Flax which collections are mutable
via mutable=[...]. Multiple can be mutable at once: a model that
carries params + batch_stats + buffers, with both batch_stats and
buffers listed as mutable, is a perfectly normal training configuration.
Three streams in one apply
This problem builds a tiny model with THREE concerns:
- params: an nn.Dense(4) layer’s kernel and bias. Initialized from the params rng.
- dropout: nn.Dropout(rate) consumes a dropout rng at apply time when deterministic=False.
- buffers: a custom running_count that increments by 1 each apply call. Lives in a user-defined collection named "buffers".
All three coexist. Init handles all three; apply uses all three.
Declaring running_count
running_count = self.variable("buffers", "running_count",
lambda: jnp.array(0.0))
Three arguments to self.variable:
- The collection name ("buffers"; your choice).
- The variable name within the collection.
- An init function that returns the variable’s initial value.
Reading: running_count.value. Writing: running_count.value = ....
The is_initializing() guard
During model.init(...), Flax runs __call__ once to create the
variables, and collections such as buffers are mutable for that call.
If you write to a variable during init, the write persists into the
returned variables: running_count would start at 1.0 rather than 0.0,
which is not what you want. The guard:
if not self.is_initializing():
    running_count.value = running_count.value + 1.0
This way the increment fires only during real apply calls,
not during init.
Init with multiple RNGs
model.init accepts a single PRNGKey OR a dict of named RNGs:
init_rng, dropout_rng = jax.random.split(rng)
variables = model.init(
{"params": init_rng, "dropout": dropout_rng},
x,
)
Flax routes each RNG to the code path that asks for it: the params
rng seeds Dense’s kernel init, and the dropout rng seeds any
Dropout that fires during init. With deterministic=False,
Dropout does fire during init; its masked output is simply discarded,
because init returns only the variables.
Apply with mutable + rngs
out, mutated = model.apply(
variables,
x,
rngs={"dropout": dropout_rng},
mutable=["buffers"],
)
- rngs={"dropout": ...} provides the dropout RNG for this call.
- mutable=["buffers"] lets the __call__ body write to the buffers collection.
- The return is (out, mutated). mutated["buffers"]["running_count"] is the post-call value.
Note: params is immutable by default (you don’t list it in
mutable). The optimizer mutates params outside apply, by
receiving grads and producing updates. Other collections (like
batch_stats or our buffers) update inside apply, so they
must be listed.
Common pitfalls
- Forgetting mutable=: if your __call__ writes to a collection that’s not in mutable=[...], Flax raises a clear error.
- Forgetting the is_initializing() guard: the buffers collection is mutable during init, so an unguarded write persists and running_count starts at 1.0 instead of 0.0.
- Forgetting rngs={"dropout": ...} at apply: Dropout with deterministic=False requires a fresh dropout key each call.
- Mixing up variables (the full dict) and params (just the params subtree): pass variables to apply, not just variables["params"]. The buffers collection lives alongside params at the top level.
Problem
Build a tiny nn.Module (@nn.compact) with THREE pieces:
- nn.Dense(4) (params).
- nn.Dropout(rate=drop_rate, deterministic=False).
- self.variable("buffers", "running_count", lambda: jnp.array(0.0)), incremented by 1 each call (guarded by is_initializing).
Init with both params and dropout rngs. Apply once with
rngs={"dropout": ...} and mutable=["buffers"]. Return:
[running_count_after, output_sum]
running_count_after is 1.0 after a single apply (started at
0.0, incremented by 1.0). output_sum is jnp.sum(out) cast
to a Python float.
Inputs:
- seed: float (cast to int); base PRNG seed.
- x: 2-D (N, in_dim).
- drop_rate: float; dropout probability.
Output: 1-D (2,).