medium primitives

NNX Frozen Attribute

Why this matters

Not every array attribute on a Module is a parameter or even a Variable. Sometimes you just want to STAPLE an array onto the module — frozen constants, lookup tables loaded from disk, or per-architecture biases that must NEVER receive a gradient.

nnx supports this with a simple convention: plain jax.Array attributes that are NOT wrapped in nnx.Param or nnx.Variable are visible to the forward pass but skipped by the type-filtered state APIs:

  • model.frozen_bias can be used in the forward computation like any other array (it’s just a jax.Array).
  • nnx.state(model, nnx.Param) does NOT include it (no Param wrapper).
  • nnx.state(model, nnx.Variable) does NOT include it (no Variable wrapper).
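  • Optax updates will not touch it. Gradients are computed only for the Param-filtered state (nnx.grad differentiates with respect to nnx.Param by default), so there is no gradient entry for the frozen array.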

Using a plain attribute is the lightest-weight way to encode “this array is part of the model but I will manage it myself.”

Three flavors of array on an nnx module

self.kernel = nnx.Param(...)           # trainable; updated by the optimizer
self.running_mean = nnx.Variable(...)  # mutable, not trained; selected by Variable-type filters
self.frozen_bias = jnp.array(...)      # plain attribute; ignored by Param/Variable filters

The plain attribute is what we want when:

  • Loading a precomputed RoPE frequency table.
  • Building an attention mask once and reusing it.
  • Encoding a fixed positional embedding (sinusoidal) you’ll never train.
  • Any “this is data, but I don’t want optax to mess with it.”

Worked example

import jax
import jax.numpy as jnp
from flax import nnx

class FrozenAttr(nnx.Module):
    def __init__(self, in_features, out_features, frozen_bias_value, rngs):
        key = rngs.params()
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))
        )
        # Plain jax.Array — NOT a nnx.Param or nnx.Variable
        self.frozen_bias = jnp.ones((out_features,)) * frozen_bias_value

    def __call__(self, x):
        return x @ self.kernel + self.frozen_bias

model = FrozenAttr(in_features=3, out_features=4, frozen_bias_value=2.0,
                   rngs=nnx.Rngs(0))

# 1 leaf — only kernel; frozen_bias is invisible to the Param filter
params = nnx.state(model, nnx.Param)
print(len(jax.tree_util.tree_leaves(params)))     # 1

# But the forward pass uses both
print(model(jnp.array([1.0, 2.0, 3.0])))

The key intuition: type-filtered state IS the parameter universe. An optimizer built over nnx.Param sees only what nnx.state(model, nnx.Param) returns. Anything outside that view is “frozen by virtue of being invisible.”

Caveats / version notes

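  • In Flax 0.12.x, nnx.split(model) and unfiltered nnx.state(model) may capture the plain array as a generic state leaf (it’s still part of the pytree for serialization). The Param filter is what gives you the “trainable subset.” How bare array attributes are handled has changed across Flax releases, so check the behavior of your installed version before relying on this pattern.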
  • “Frozen” here means “not trained,” not “immutable.” A jax.Array can’t be modified in place, but ordinary attribute assignment can still rebind it (model.frozen_bias = ...). Treat the attribute as read-only by convention.
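  • If you also need the attribute to be device-sharded or carried through nnx.jit as explicit state, prefer wrapping it in an nnx.Variable (or a subclass of it). nnx.Param subclasses nnx.Variable, not the reverse, so a plain Variable is not matched by the nnx.Param filter. An optimizer built over nnx.Param therefore still skips it; just avoid broad nnx.Variable filters when selecting what to train.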

Common pitfalls

  • Wrapping a frozen array as nnx.Param. Optax will train it.
  • Plain attributes under nnx.jit. Depending on the Flax version, a plain array attribute may be handled differently from Variable-wrapped state under nnx.jit (for example, treated as “static” and re-traced when it changes). The eager problems in this sub-sprint don’t use jit, so this isn’t an issue here.
  • Computing jnp.linalg.norm on a Variable wrapper. Use model.frozen_bias directly (it’s already an array, not a Variable). For a Param, use model.kernel.value.

Problem

Write frozen_kernel_unchanged(seed, x, features, frozen_bias_value):

  1. Define FrozenAttr(nnx.Module). __init__ allocates self.kernel = nnx.Param(...) with shape (in_features, out_features), initialized from a standard normal scaled by 1/sqrt(in_features). It then assigns self.frozen_bias = jnp.ones((out_features,)) * float(frozen_bias_value) as a PLAIN array (no wrapper). Forward: x @ self.kernel + self.frozen_bias.
  2. Build with nnx.Rngs(int(seed)), instantiate the module (in_features=x.shape[-1], out_features=int(features), frozen_bias_value=float(frozen_bias_value)).
  3. Compute:
    • params = nnx.state(model, nnx.Param). Count its leaves (n_param_leaves). Should be 1 — only the kernel.
    • kernel_norm = jnp.linalg.norm(model.kernel.value).
    • frozen_norm = jnp.linalg.norm(model.frozen_bias).
  4. Return jnp.array([float(n_param_leaves), float(kernel_norm), float(frozen_norm)]).

Note that n_param_leaves == 1 even though the model has TWO array attributes: only the nnx.Param-wrapped one matches the Param filter.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).
  • frozen_bias_value: float — fills the frozen_bias.

Output: length-3 array [1.0, kernel_norm, frozen_norm].

Hints

flax nnx frozen param-filter
