Custom Parameter Initializer

Why initialization matters

The initial values of neural network weights determine whether training converges at all. Bad initialization causes two classic failure modes:

  1. Vanishing gradients: if weights are too small, the signal shrinks layer by layer during the backward pass, leaving early layers with near-zero gradients that never train.
  2. Exploding gradients: if weights are too large, activations and gradients blow up exponentially with depth.

Good initializers keep the variance of activations roughly constant from layer to layer, neither shrinking nor exploding, regardless of layer width. This is the core insight behind LeCun, Xavier, and He initialization.
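A minimal sketch of both failure modes, assuming a stack of purely linear layers with i.i.d. Gaussian weights and no biases or activations. The function name final_activation_std, the width, the depth, and the three standard deviations are illustrative choices, not part of any library.

```python
import jax
import jax.numpy as jnp


def final_activation_std(weight_std, width=256, depth=20, seed=0):
    """Push a random vector through `depth` linear layers; return its final std."""
    key = jax.random.PRNGKey(seed)
    key, x_key = jax.random.split(key)
    x = jax.random.normal(x_key, (width,))  # input with variance ~1
    for _ in range(depth):
        key, w_key = jax.random.split(key)
        # Fresh Gaussian weight matrix with the requested standard deviation.
        w = weight_std * jax.random.normal(w_key, (width, width))
        x = w @ x
    return float(x.std())


width = 256
too_small = final_activation_std(0.01, width)          # variance shrinks every layer
balanced = final_activation_std(width ** -0.5, width)  # sigma_w^2 = 1 / n_in
too_large = final_activation_std(0.2, width)           # variance grows every layer

print(f"std 0.01        -> {too_small:.3e}")
print(f"std 1/sqrt(n_in) -> {balanced:.3e}")
print(f"std 0.2         -> {too_large:.3e}")
```

At width 256, a weight standard deviation of 0.01 scales the signal by roughly 0.16 per layer and 0.2 scales it by roughly 3.2 per layer, so only the 1/sqrt(n_in) choice should stay near the input's scale after 20 layers.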

Signal propagation math (simplified)

Consider a linear layer y = Wx where x is a vector of n_in i.i.d. zero-mean inputs with variance σ²_x, and W is a matrix of i.i.d. zero-mean weights with variance σ²_w, independent of x. Each output unit is a sum of n_in independent terms, so its variance is:

Var(y_i) = n_in * σ²_w * σ²_x

To keep Var(y) = Var(x) (no amplification, no shrinkage), we need:

σ²_w = 1 / n_in

That’s LeCun normal: draw weights from N(0, 1/n_in), i.e. standard deviation sqrt(1/n_in). It is the default kernel_init of Flax’s nn.Dense, where lecun_normal() samples from a truncated normal rescaled to this standard deviation.

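Xavier/Glorot normal replaces n_in with the average of fan-in and fan-out, a compromise between preserving variance in the forward pass (which wants 1/n_in) and in the backward pass (which wants 1/n_out):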

σ²_w = 2 / (n_in + n_out)

Good for tanh/sigmoid activations. He/Kaiming normal targets ReLU networks:

σ²_w = 2 / n_in   (accounts for ReLU zeroing half the activations)
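The factor of 2 can be checked numerically. The sketch below assumes a bias-free stack of ReLU layers with Gaussian weights of variance scale / n_in and tracks the mean squared activation. relu_stack_mean_square and its parameters are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def relu_stack_mean_square(scale, width=256, depth=20, seed=0):
    """Mean squared activation after `depth` ReLU layers with Var(w) = scale / width."""
    key = jax.random.PRNGKey(seed)
    key, x_key = jax.random.split(key)
    x = jax.random.normal(x_key, (width,))  # mean square ~1 at the input
    weight_std = (scale / width) ** 0.5
    for _ in range(depth):
        key, w_key = jax.random.split(key)
        w = weight_std * jax.random.normal(w_key, (width, width))
        x = jax.nn.relu(w @ x)  # ReLU zeroes roughly half of the units
    return float(jnp.mean(x ** 2))


he = relu_stack_mean_square(2.0)     # sigma_w^2 = 2 / n_in
lecun = relu_stack_mean_square(1.0)  # sigma_w^2 = 1 / n_in

print(f"He scale (2/n_in):    {he:.3e}")
print(f"LeCun scale (1/n_in): {lecun:.3e}")
```

With scale = 1 the mean square roughly halves at every ReLU layer, so 20 layers lose a factor on the order of 2^20. With scale = 2 that loss is compensated and the signal stays near its starting magnitude.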

Flax initializers API

Flax exposes initializers in flax.linen.initializers (aliased as nn.initializers). Most of them are factories: calling one, e.g. lecun_normal(), returns a function that takes (key, shape, dtype) and returns an array. zeros and ones are the exception; they are already initializer functions.

nn.initializers.lecun_normal()    # LeCun normal (default for Dense kernel)
nn.initializers.zeros             # note: no () — it's already a callable
nn.initializers.glorot_uniform()  # Xavier uniform
nn.initializers.glorot_normal()   # Xavier normal
nn.initializers.he_normal()       # He/Kaiming normal
nn.initializers.he_uniform()      # He/Kaiming uniform
nn.initializers.ones              # constant initializer
nn.initializers.normal(stddev=0.01)  # custom normal

Pass them to nn.Dense (or any Flax layer) as keyword arguments:

nn.Dense(
    features=self.features,
    kernel_init=nn.initializers.lecun_normal(),
    bias_init=nn.initializers.zeros,
)

Worked contrast example

import jax
import jax.numpy as jnp
import flax.linen as nn

class NormalInit(nn.Module):
    features: int
    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features,
            kernel_init=nn.initializers.normal(stddev=0.01))(x)

class LecunInit(nn.Module):
    features: int
    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features,
            kernel_init=nn.initializers.lecun_normal())(x)

key = jax.random.PRNGKey(0)
x = jnp.ones((64,))

p_small = NormalInit(features=64).init(key, x)
p_lecun = LecunInit(features=64).init(key, x)

k_small = p_small["params"]["Dense_0"]["kernel"]
k_lecun = p_lecun["params"]["Dense_0"]["kernel"]

print(f"small stddev kernel std: {k_small.std():.4f}")   # ~0.0100
print(f"lecun kernel std:         {k_lecun.std():.4f}")   # ~0.1250 = 1/sqrt(64)

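The LeCun kernel’s standard deviation is about 12.5x larger and matches 1/sqrt(n_in) = 0.125. The difference compounds with depth: with stddev=0.01 and n_in = 64, each layer multiplies the activation variance by 64 * 0.01² = 0.0064, so the signal all but disappears after a few layers, while LeCun keeps that factor at 1.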

Bias vs kernel convention

Biases are almost always initialized to zero. The intuition: a nonzero bias shifts every pre-activation before the network has seen any data, so some units start out strongly active or, with a negative bias under ReLU, stuck at zero. Zero bias lets the network start from a neutral position and learn the offset it needs.

Kernel initialization is where the real work is done.

Common pitfalls

  • Forgetting () on factory initializers: lecun_normal is a factory, not the initializer itself. You must call it, lecun_normal(), and pass the result. Passing lecun_normal (without ()) hands the layer the factory instead of an initializer and will fail at runtime. Note: zeros and ones are already initializers (not factories), so they don’t need ().
  • Mismatching initializer to activation: using glorot_normal (designed for tanh/sigmoid) on a ReLU network works, but is suboptimal. For ReLU, prefer he_normal.
  • Ignoring dtype: parameters are created in float32 by default. In Flax layers the dtype handed to the initializer comes from the layer’s param_dtype argument, so if you want bfloat16 parameters, set it explicitly.

Problem

Build a Dense Flax Module using @nn.compact with explicit initializers: lecun_normal() for the kernel and zeros for the bias. Init and apply.

Inputs:

  • seed: int (passed as float — cast to int).
  • x: 1-D JAX array.
  • features: int (cast to int).

Output: 1-D array of length features.

Hints

flax param initializer
