Implement Dense from Scratch

Why this matters

nn.Dense is the workhorse of most neural networks. Reimplementing it from scratch (using self.param directly instead of calling the built-in) is the cleanest way to internalize three things:

  1. How self.param declares trainable variables.
  2. The exact kernel/bias shape convention Flax uses (and why).
  3. How init figures out the input feature size automatically from the traced input.

Once you’ve built Dense yourself, the rest of the layer-reimplementation problems (Conv, LayerNorm, etc.) are straightforward — they’re all the same pattern with different math.

self.param API

Inside an @nn.compact __call__, you declare a parameter with:

p = self.param("name", initializer, shape)
  • name: string — the key in the params dict.
  • initializer: a callable (key, shape, dtype) -> array. Use one from flax.linen.initializers:
    • nn.initializers.lecun_normal() for kernels: a factory that returns an initializer sampling a truncated normal with stddev 1/√fan_in.
    • nn.initializers.zeros for biases (note: no parens — it’s already a callable).
    • nn.initializers.ones, nn.initializers.normal(stddev), etc.
  • shape: tuple of ints, forwarded to the initializer after the PRNG key. The initializer only runs during init; on apply, self.param returns the stored value instead.

Dense math, formally

Dense applies an affine map: y = x @ W + b. Flax’s convention:

  • W has shape (in_features, out_features).
  • b has shape (out_features,).
  • x @ W contracts the last axis of x with the first axis of W, so any leading (batch) axes of x pass through unchanged.

For 1-D input x of shape (in,) and out_features=K:

  • W is (in, K), b is (K,), output is (K,).
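
The shape bookkeeping can be checked with plain jax.numpy, before any Flax is involved. This sketch uses random W and zero b purely for illustration (not the Flax initializers) and also shows the batched case, which the convention supports for free.

```python
import jax
import jax.numpy as jnp

in_features, K = 4, 3
k_w, k_x, k_xb = jax.random.split(jax.random.PRNGKey(0), 3)

W = jax.random.normal(k_w, (in_features, K))  # (in, out) layout
b = jnp.zeros((K,))                            # (out,)

x = jax.random.normal(k_x, (in_features,))     # 1-D input of shape (in,)
y = x @ W + b
print(y.shape)  # (3,)

# Each output is x dotted with one column of W, plus that column's bias.
assert jnp.allclose(y[0], jnp.sum(x * W[:, 0]) + b[0], atol=1e-5)

# A leading batch axis passes through: @ contracts x's last axis with W's
# first axis, and b broadcasts across the batch.
xb = jax.random.normal(k_xb, (8, in_features))
print((xb @ W + b).shape)  # (8, 3)
```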

Worked implementation

import jax.numpy as jnp
import flax.linen as nn

class MyDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        kernel = self.param(
            "kernel",
            nn.initializers.lecun_normal(),
            (x.shape[-1], self.features),  # in_features inferred from input
        )
        bias = self.param(
            "bias",
            nn.initializers.zeros,
            (self.features,),
        )
        return jnp.dot(x, kernel) + bias

Notice x.shape[-1] — at init, x is the example input passed to model.init(key, x), so the kernel’s input dimension is inferred. This is the canonical “lazy init” pattern in Flax: shapes come from the trace.

Why use lecun_normal?

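lecun_normal draws weights from a truncated normal with mean 0 and variance 1/fan_in (stddev 1/√fan_in). Each pre-activation sums fan_in products, each with variance about 1/fan_in when the inputs have unit variance, so the pre-activation variance stays roughly 1. That keeps activations from exploding or vanishing at the start of training. It's Flax's default kernel initializer for nn.Dense and nn.Conv.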

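Glorot (Xavier) initialization, glorot_normal, uses variance 2/(fan_in + fan_out) and was derived with tanh/sigmoid-style activations in mind. For ReLU, he_normal (Kaiming) uses variance 2/fan_in to compensate for ReLU zeroing roughly half of its inputs. The right initializer depends on the activation that follows: the framework default is a reasonable starting point but not always the best choice.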

Common pitfalls

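  • (self.features, x.shape[-1]) instead of (x.shape[-1], self.features): Flax convention is (in, out). With the wrong order, jnp.dot(x, kernel) tries to contract x's last axis against an axis of size features, so init itself fails with a shape error whenever in_features ≠ features. When the two happen to be equal, nothing errors and the bug stays hidden until a size changes. Memorize: input axis first.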
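  • Calling nn.initializers.zeros(): zeros is already an initializer, unlike lecun_normal() or normal(stddev), which are factories that return one. nn.initializers.zeros works; nn.initializers.zeros() raises a TypeError because key and shape are missing.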
  • Hardcoding in_features: Don't pass it as a Module attribute. Read x.shape[-1] from the trace; declaring parameters inline in @nn.compact is exactly what makes that inference possible.

Problem

Implement MyDense(features) using self.param:

  1. Kernel shape (x.shape[-1], features), init with lecun_normal().
  2. Bias shape (features,), init with zeros.
  3. Forward: jnp.dot(x, kernel) + bias.

The function:

  • Takes seed, x, features.
  • Inits with PRNGKey(seed), applies to x, returns the output.

Output: 1-D array of length features.
