Implement Dense from Scratch
Why this matters
nn.Dense is the workhorse of every modern model. Reimplementing it from
scratch — using self.param directly instead of calling the built-in — is
the cleanest way to internalize three things:
- How self.param declares trainable variables.
- The exact kernel/bias shape convention Flax uses (and why).
- How init figures out the input feature size automatically from the traced input.
Once you’ve built Dense yourself, the rest of the layer-reimplementation
problems (Conv, LayerNorm, etc.) are straightforward — they’re all the same
pattern with different math.
self.param API
Inside an @nn.compact __call__, you declare a parameter with:
```python
p = self.param("name", initializer, shape)
```
- name: string, the key in the params dict.
- initializer: a callable (key, shape, dtype) -> array (dtype is optional). Use one from flax.linen.initializers:
  - nn.initializers.lecun_normal() for kernels (stddev = 1/√fan_in). The parens are needed here because lecun_normal is a factory that returns the initializer.
  - nn.initializers.zeros for biases (note: no parens; it's already a callable).
  - nn.initializers.ones, nn.initializers.normal(stddev), etc.
- shape: tuple of ints. During init, self.param calls the initializer with a PRNG key and this shape.
Dense math, formally
Dense applies an affine map: y = x @ W + b. Flax’s convention:
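- W has shape (in_features, out_features).
- b has shape (out_features,).
- x @ W contracts the last axis of x with the first axis of W, so any leading batch axes of x pass through unchanged and b broadcasts over them. Putting the input axis first is what lets x be used as-is, with no transpose.
For 1-D input x of shape (in,) and out_features=K:
- W is (in, K), b is (K,), and the output is (K,).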
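The shape convention can be checked with plain jax.numpy, no Flax required. This sketch uses random values for W and a zero bias; the sizes in_features=5 and K=3 are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

in_features, K = 5, 3
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))

W = jax.random.normal(key_w, (in_features, K))  # (in, out): input axis first
b = jnp.zeros((K,))                             # (out,)

# 1-D input: (in,) @ (in, K) -> (K,)
x = jax.random.normal(key_x, (in_features,))
y = x @ W + b
print(y.shape)  # (3,)

# Batched input: the last axis of xb is contracted with the first axis of W,
# the leading batch axis passes through, and b broadcasts over it.
xb = jnp.stack([x, 2 * x])  # (2, in)
yb = xb @ W + b             # (2, K)
print(yb.shape)  # (2, 3)
```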
Worked implementation
```python
import jax.numpy as jnp
import flax.linen as nn


class MyDense(nn.Module):
    features: int  # out_features; in_features is NOT an attribute

    @nn.compact
    def __call__(self, x):
        kernel = self.param(
            "kernel",
            nn.initializers.lecun_normal(),
            (x.shape[-1], self.features),  # in_features inferred from input
        )
        bias = self.param(
            "bias",
            nn.initializers.zeros,
            (self.features,),
        )
        return jnp.dot(x, kernel) + bias
```
Notice x.shape[-1] — at init, x is the example input passed to
model.init(key, x), so the kernel’s input dimension is inferred. This is
the canonical “lazy init” pattern in Flax: shapes come from the trace.
Why use lecun_normal?
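lecun_normal draws weights from a truncated normal distribution with mean 0 and variance 1/fan_in (stddev 1/√fan_in). For independent, zero-mean inputs with unit variance, each pre-activation sums fan_in terms of variance 1/fan_in, so its variance is roughly fan_in × (1/fan_in) = 1. Keeping it near 1 stops activations from exploding or vanishing layer by layer, which helps avoid early-training divergence. It's the Flax default for nn.Dense and nn.Conv kernels.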
For a tanh-style activation, glorot_normal (Xavier) is more appropriate; for
ReLU, he_normal (Kaiming) preserves variance through ReLUs. The “right”
initializer depends on what comes after — the framework defaults are
reasonable but not always optimal.
Common pitfalls
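- Writing (self.features, x.shape[-1]) instead of (x.shape[-1], self.features): Flax's convention is (in, out). Because init runs the forward pass, the wrong order fails with a shape error inside jnp.dot whenever in_features != features. When the two are equal, nothing errors, but the kernel is silently laid out as the transpose of nn.Dense's, so its params are not interchangeable with nn.Dense. Memorize: input axis first.
- Calling nn.initializers.zeros(): it's already a callable, so no parens. nn.initializers.zeros works; nn.initializers.zeros() raises a TypeError (missing key and shape).
- Hardcoding in_features: don't pass it as a Module attribute. Read x.shape[-1] from the trace; inferring shapes from the input is the whole point of declaring parameters inline with @nn.compact.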
Problem
Implement MyDense(features) using self.param:
- Kernel shape (x.shape[-1], features), initialized with lecun_normal().
- Bias shape (features,), initialized with zeros.
- Forward: jnp.dot(x, kernel) + bias.
The function:
- Takes seed, x, features.
- Initializes with PRNGKey(seed), applies to x, and returns the output.
Output: 1-D array of length features.