
Multiple Mutable Collections

Why this matters

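Flax’s variable system isn’t limited to params: variables can live in any collection name you invent. Each collection has its own lifecycle and its own mutability, and on every apply call you choose which collections may be written. RNG streams (such as "params" and "dropout") are a separate namespace of named keys, not collections. The standard collections you’ve already seen: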

  • params — trainable parameters; updated by the optimizer.
  • batch_stats — running mean/var for BatchNorm (pos 17).
  • cache — KV-cache for attention at decode time (pos 38).

But you can declare your own collection for any non-parameter state you want the model to keep around: counters, EMA shadow weights, debug histograms, replay buffers, gating temperatures, learned-but-not-gradient state. The mechanism is uniform:

self.variable("<collection>", "<var_name>", init_fn)

During apply, you tell Flax which collections are mutable via mutable=[...]. Several can be mutable at once: a model holding params, batch_stats and buffers, applied with mutable=["batch_stats", "buffers"], is a perfectly normal training configuration.

Three streams in one apply

This problem builds a tiny model with THREE concerns:

  1. params: the kernel and bias of an nn.Dense(4) layer, initialized from the "params" RNG.
  2. dropout: nn.Dropout(rate) draws a key from the "dropout" RNG stream at apply time when deterministic=False. This is an RNG stream, not a variable collection.
  3. buffers: a custom running_count that increments by 1 on each apply call. It lives in a user-defined collection named "buffers".

All three coexist. Init handles all three; apply uses all three.

Declaring running_count

running_count = self.variable("buffers", "running_count",
                               lambda: jnp.array(0.0))

Three arguments to self.variable:

  1. The collection name ("buffers" — your choice).
  2. The variable name within the collection.
  3. An init function that returns the variable’s initial value.

Reading: running_count.value. Writing: running_count.value = ....

The is_initializing() guard

During model.init(...), Flax runs __call__ once on the example input to create the variables, and every collection is mutable during init. Without a guard, the increment would also run during init, so the returned variables would already hold running_count = 1.0 and a single apply would report 2.0. The guard:

if not self.is_initializing():
    running_count.value = running_count.value + 1.0

This way the increment fires only during real apply calls, not during init.

Init with multiple RNGs

model.init accepts either a single PRNGKey (used as the "params" RNG) or a dict of named RNGs:

init_rng, dropout_rng = jax.random.split(rng)
variables = model.init(
    {"params": init_rng, "dropout": dropout_rng},
    x,
)

Flax routes each key to its named stream: the "params" key seeds Dense’s kernel init, and the "dropout" key feeds Dropout. Because deterministic=False, Dropout really does run during init and, for a nonzero rate, draws a mask; init simply discards the output and returns only the variables.

Apply with mutable + rngs

out, mutated = model.apply(
    variables,
    x,
    rngs={"dropout": dropout_rng},
    mutable=["buffers"],
)
  • rngs={"dropout": ...} provides the dropout RNG for this call.
  • mutable=["buffers"] lets the __call__ body write to the buffers collection.
  • The return is (out, mutated). mutated["buffers"]["running_count"] is the post-call value.

Note: in apply, a collection is writable only if it is listed in mutable, and params normally are not. Parameters change outside apply: the optimizer takes the gradients and returns new parameter values. Collections such as batch_stats or our buffers are updated inside apply, so they must be listed.

Common pitfalls

  • Forgetting mutable=: if your __call__ writes to a collection that’s not in mutable=[...], Flax raises a clear error.
  • Forgetting the is_initializing() guard: init treats every collection as mutable, so the write succeeds and ends up in the initial variables. The counter then starts at 1.0 instead of 0.0, and one apply reports 2.0.
  • Forgetting rngs={"dropout": ...} at apply: Dropout with deterministic=False needs a dropout key on every call; pass a fresh one each step, or every step reuses the same mask.
  • Mixing up variables (the full dict) and params (just the params subtree): pass variables to apply, not just variables["params"]. The buffers collection lives alongside params at the top level; dropout is an RNG stream and never appears in variables.

Problem

Build a tiny nn.Module (@nn.compact) with THREE pieces:

  • nn.Dense(4) (params).
  • nn.Dropout(rate=drop_rate, deterministic=False).
  • self.variable("buffers", "running_count", lambda: jnp.array(0.0)), incremented by 1 each call (guarded by is_initializing).

Init with both params and dropout rngs. Apply once with rngs={"dropout": ...} and mutable=["buffers"]. Return:

[running_count_after, output_sum]

running_count_after is 1.0 after a single apply (started at 0.0, incremented by 1.0). output_sum is jnp.sum(out) cast to a Python float.

Inputs:

  • seed: float (cast to int) — base PRNG seed.
  • x: 2-D (N, in_dim).
  • drop_rate: float — dropout probability.

Output: 1-D (2,).

Hints

flax variables collections
