NNX Frozen Attribute
Why this matters
Not every array attribute on a Module is a parameter or even a Variable. Sometimes you just want to STAPLE an array onto the module — frozen constants, lookup tables loaded from disk, or per-architecture biases that must NEVER receive a gradient.
nnx supports this with a simple convention: plain jax.Array attributes
that are NOT wrapped in nnx.Param or nnx.Variable are visible to the
forward pass but skipped by the type-filtered state APIs:
- model.frozen_bias works in math (it’s just an array).
- nnx.state(model, nnx.Param) does NOT include it (no Param wrapper).
- nnx.state(model, nnx.Variable) does NOT include it (no Variable wrapper).
- Optax’s optimizer.update(grads) will not touch it (no gradient flows to a frozen reference).
Using a plain attribute is the lightest-weight way to encode “this array is part of the model but I will manage it myself.”
Three flavors of array on an nnx module
self.kernel = nnx.Param(...) # trainable; updated by optimizer
self.running_mean = nnx.Variable(...) # mutable; not trainable; in state pytree under filters
self.frozen_bias = jnp.array(...) # plain attribute; ignored by Param/Variable filters
The plain attribute is what we want when:
- Loading a precomputed RoPE frequency table.
- Building an attention mask once and reusing it.
- Encoding a fixed positional embedding (sinusoidal) you’ll never train.
- Any “this is data, but I don’t want optax to mess with it.”
Worked example
import jax
import jax.numpy as jnp
from flax import nnx

class FrozenAttr(nnx.Module):
    def __init__(self, in_features, out_features, frozen_bias_value, rngs):
        key = rngs.params()
        self.kernel = nnx.Param(
            jax.random.normal(key, (in_features, out_features))
            * (1.0 / jnp.sqrt(in_features))
        )
        # Plain jax.Array, NOT an nnx.Param or nnx.Variable
        self.frozen_bias = jnp.ones((out_features,)) * frozen_bias_value

    def __call__(self, x):
        return x @ self.kernel + self.frozen_bias

model = FrozenAttr(in_features=3, out_features=4, frozen_bias_value=2.0,
                   rngs=nnx.Rngs(0))

# 1 leaf: only kernel; frozen_bias is invisible to the Param filter
params = nnx.state(model, nnx.Param)
print(len(jax.tree_util.tree_leaves(params)))  # 1

# But the forward pass uses both
print(model(jnp.array([1.0, 2.0, 3.0])))
The key intuition: type-filtered state IS the parameter universe.
Optax sees only what nnx.state(model, nnx.Param) returns; anything
outside that view is “frozen by virtue of being invisible.”
Caveats / version notes
- In Flax 0.12.x, nnx.split(model) and unfiltered nnx.state(model) may capture the plain array as a generic state leaf (it’s still part of the pytree for serialization). The Param filter is what gives you the “trainable subset.”
- For full immutability, meaning the value can’t even be reassigned, a plain attribute alone is not enough. Note that dataclasses.field(init=False) only keeps a field out of the generated __init__ and does not block reassignment, so the practical option is to not expose a setter.
- If you also need the attribute to be device-sharded or carried through nnx.jit, prefer wrapping it as nnx.Variable and accept that you have to type-filter it out before optimizer updates.
Common pitfalls
- Wrapping a frozen array as nnx.Param. Optax will train it.
- Plain attribute and nnx.jit complication. Pure plain attributes are sometimes treated as “static” by jit (re-traced if changed). For the eager problems in this sub-sprint we don’t jit, so this isn’t an issue.
- Computing jnp.linalg.norm on a Variable wrapper. Use model.frozen_bias directly (it’s already an array, not a Variable). For a Param, use model.kernel.value.
Problem
Write frozen_kernel_unchanged(seed, x, features, frozen_bias_value):
- Define FrozenAttr(nnx.Module). __init__ allocates self.kernel = nnx.Param(...) with shape (in_features, out_features), init normal * (1/sqrt(in_features)). Then assigns self.frozen_bias = jnp.ones((out_features,)) * float(frozen_bias_value) as a PLAIN array (no wrapper). Forward: x @ self.kernel + self.frozen_bias.
- Build with nnx.Rngs(int(seed)), instantiate the module (in_features=x.shape[-1], out_features=int(features), frozen_bias_value=float(frozen_bias_value)).
- Compute:
  - params = nnx.state(model, nnx.Param). Count its leaves (n_param_leaves). Should be 1: only the kernel.
  - kernel_norm = jnp.linalg.norm(model.kernel.value).
  - frozen_norm = jnp.linalg.norm(model.frozen_bias).
- Return jnp.array([float(n_param_leaves), float(kernel_norm), float(frozen_norm)]).
Note n_param_leaves == 1 even though the model has TWO array attributes —
only the nnx.Param-wrapped one is filterable as a Param.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
- frozen_bias_value: float; fills the frozen_bias.
Output: length-3 array [1.0, kernel_norm, frozen_norm].
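One quick sanity check on the last entry: because frozen_bias is a constant vector of length features, its norm has a closed form, |frozen_bias_value| * sqrt(features). The sketch below checks this with illustrative values (features=4, frozen_bias_value=2.0), which are not taken from any actual test case.

```python
import jax.numpy as jnp

features, frozen_bias_value = 4, 2.0  # illustrative values only

frozen_bias = jnp.ones((features,)) * frozen_bias_value
frozen_norm = jnp.linalg.norm(frozen_bias)

# Closed form for the norm of a constant vector
closed_form = abs(frozen_bias_value) * jnp.sqrt(features)

print(float(frozen_norm), float(closed_form))  # 4.0 4.0
```

The kernel norm has no such shortcut because it depends on the random draw for the given seed.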