Custom Parameter Initializer
Why initialization matters
The initial values of neural network weights determine whether training converges at all. Bad initialization causes two classic failure modes:
- Vanishing gradients: if weights are too small, the signal shrinks layer by layer during the backward pass, leaving early layers with near-zero gradients that never train.
- Exploding gradients: if weights are too large, activations and gradients blow up exponentially with depth.
Good initializers keep the variance of activations roughly constant across layers — so neither shrinking nor exploding — regardless of input size. This is the core insight behind LeCun, Xavier, and He initialization.
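The two failure modes can be seen numerically. The following is a minimal sketch, not taken from the original lesson: it pushes a random vector through a stack of bias-free linear layers and reports the standard deviation of the final activations for three weight scales. The helper name final_std, the depth of 20, and the width of 256 are illustrative assumptions. It tracks the forward signal only; the backward pass multiplies by the transposed weights and is expected to scale in the same way.

```python
import jax
import jax.numpy as jnp


def final_std(weight_std, depth=20, width=256, seed=0):
    """Std of the activations after `depth` bias-free linear layers.

    Hypothetical helper for illustration; each layer uses fresh Gaussian
    weights with standard deviation `weight_std`.
    """
    key = jax.random.PRNGKey(seed)
    key, x_key = jax.random.split(key)
    x = jax.random.normal(x_key, (width,))
    for _ in range(depth):
        key, w_key = jax.random.split(key)
        w = weight_std * jax.random.normal(w_key, (width, width))
        x = w @ x
    return float(jnp.std(x))


width = 256
too_small = final_std(0.01, width=width)         # signal collapses toward zero
balanced = final_std(width ** -0.5, width=width)  # 1/sqrt(n_in): stays near 1
too_large = final_std(0.2, width=width)          # signal blows up

print(f"std=0.01       -> {too_small:.3e}")
print(f"std=1/sqrt(n)  -> {balanced:.3e}")
print(f"std=0.2        -> {too_large:.3e}")
```

Each layer multiplies the standard deviation by roughly sqrt(width) * weight_std, so the three cases shrink, hold, or grow geometrically with depth.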
Signal propagation math (simplified)
Consider a linear layer y = Wx where x is a vector of n_in i.i.d. zero-mean
inputs with variance σ²_x, and W is a matrix of i.i.d. zero-mean weights with
variance σ²_w, independent of x.
The variance of each output unit is:
Var(y_i) = n_in * σ²_w * σ²_x
To keep Var(y) = Var(x) (no amplification, no shrinkage), we need:
σ²_w = 1 / n_in
That’s LeCun normal: draw weights from N(0, 1/n_in), i.e. standard
deviation sqrt(1/n_in). This is the default for Flax’s nn.Dense.
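The variance formula can be checked empirically. This is a small sketch under assumed values (n_in = 512, σ_x = 2.0, σ_w = 0.05, all chosen only for illustration). It uses the row-vector convention x @ w with a kernel of shape (n_in, n_out), which is how Dense kernels are laid out, rather than the column-vector y = Wx of the text; the variance argument is the same.

```python
import jax
import jax.numpy as jnp

n_in, n_out, batch = 512, 512, 1024
sigma_x, sigma_w = 2.0, 0.05  # illustrative values, not from the lesson

x_key, w_key = jax.random.split(jax.random.PRNGKey(0))
x = sigma_x * jax.random.normal(x_key, (batch, n_in))
w_unit = jax.random.normal(w_key, (n_in, n_out))  # unit-variance weights

# Arbitrary weight scale: Var(y) should be close to n_in * σ²_w * σ²_x.
predicted = n_in * sigma_w**2 * sigma_x**2
measured = float(jnp.var(x @ (sigma_w * w_unit)))

# LeCun scale σ²_w = 1 / n_in: Var(y) should be close to Var(x).
measured_lecun = float(jnp.var(x @ (n_in**-0.5 * w_unit)))

print(f"predicted Var(y): {predicted:.3f}, measured: {measured:.3f}")
print(f"Var(x): {sigma_x**2:.3f}, Var(y) with LeCun scale: {measured_lecun:.3f}")
```

The measured values are sample estimates, so they should agree with the formula to within a few percent rather than exactly.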
Xavier/Glorot normal balances the forward and backward passes by using the average of fan-in and fan-out in place of n_in:
σ²_w = 2 / (n_in + n_out)
Good for tanh/sigmoid activations. He/Kaiming normal targets ReLU networks:
σ²_w = 2 / n_in (accounts for ReLU zeroing half the activations)
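The factor of 2 in He initialization can be illustrated with a quick experiment. The sketch below is an assumption-laden check rather than part of the original lesson: it compares the mean squared activation after a ReLU for the LeCun variance 1/n_in and the He variance 2/n_in. The sizes and the helper name second_moment_after_relu are illustrative.

```python
import jax
import jax.numpy as jnp

n_in, n_out, batch = 512, 512, 1024
x_key, w_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(x_key, (batch, n_in))  # zero-mean, unit-variance inputs


def second_moment_after_relu(weight_variance):
    """Mean of relu(x @ w)**2 for Gaussian weights with the given variance."""
    w = jnp.sqrt(weight_variance) * jax.random.normal(w_key, (n_in, n_out))
    h = jax.nn.relu(x @ w)
    return float(jnp.mean(h**2))


input_moment = float(jnp.mean(x**2))                   # roughly 1
lecun_moment = second_moment_after_relu(1.0 / n_in)    # roughly half the input
he_moment = second_moment_after_relu(2.0 / n_in)       # roughly equal to the input

print(f"input: {input_moment:.3f}  lecun+relu: {lecun_moment:.3f}  he+relu: {he_moment:.3f}")
```

With the LeCun variance, the ReLU discards the negative half of a symmetric pre-activation and about half of the signal energy is lost per layer; doubling the weight variance compensates for that loss.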
Flax initializers API
Flax exposes initializers in flax.linen.initializers (aliased as
nn.initializers). Most initializers are factories: calling one returns a
function that takes (key, shape, dtype) and returns an array. zeros and ones
are the exception: they are already functions with that signature.
nn.initializers.lecun_normal()         # LeCun normal (default for Dense kernel)
nn.initializers.zeros                  # note: no (), it's already a callable
nn.initializers.glorot_uniform()       # Xavier uniform
nn.initializers.glorot_normal()        # Xavier normal
nn.initializers.he_normal()            # He/Kaiming normal
nn.initializers.he_uniform()           # He/Kaiming uniform
nn.initializers.ones                   # constant initializer
nn.initializers.normal(stddev=0.01)    # custom normal
Pass them to nn.Dense (or any Flax layer) as keyword arguments:
nn.Dense(
    features=self.features,
    kernel_init=nn.initializers.lecun_normal(),
    bias_init=nn.initializers.zeros,
)
Worked contrast example
import jax
import jax.numpy as jnp
import flax.linen as nn


class NormalInit(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features,
                        kernel_init=nn.initializers.normal(stddev=0.01))(x)


class LecunInit(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features,
                        kernel_init=nn.initializers.lecun_normal())(x)


key = jax.random.PRNGKey(0)
x = jnp.ones((64,))

# Initialize both modules with the same key and input shape.
p_small = NormalInit(features=64).init(key, x)
p_lecun = LecunInit(features=64).init(key, x)

k_small = p_small["params"]["Dense_0"]["kernel"]
k_lecun = p_lecun["params"]["Dense_0"]["kernel"]

print(f"small stddev kernel std: {k_small.std():.4f}")  # ~0.0100
print(f"lecun kernel std: {k_lecun.std():.4f}")         # ~0.1250 = 1/sqrt(64)
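The LeCun kernel's standard deviation is about 12.5x larger, matching the
expected 1/sqrt(n_in) = 0.125 for n_in = 64. The difference compounds with
depth: by the variance formula above, stddev=0.01 with n_in = 64 scales the
variance by 64 × 0.01² = 0.0064 at every layer, so the signal's standard
deviation drops to 8% of its previous value each time.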
Bias vs kernel convention
Biases are almost always initialized to zero. The intuition: a nonzero bias offset at init shifts every pre-activation before the network has seen any data, so some neurons start out strongly activated (or, with a negative offset under ReLU, switched off) for reasons unrelated to the input. Zero bias lets the network start from a neutral position and learn the offset it needs.
Kernel initialization is where the real work is done.
Common pitfalls
- Forgetting () on factory initializers: lecun_normal is a factory that returns the initializer, not the initializer itself. You must call it: lecun_normal(). Passing lecun_normal (without ()) will fail at runtime. Note: zeros and ones are already initializer functions (not factories), so they don’t need ().
- Mismatching initializer to activation: using glorot_normal (designed for tanh/sigmoid) on a ReLU network works, but is suboptimal. For ReLU, prefer he_normal.
- Ignoring dtype: most initializers produce float32 by default. If you’re using bfloat16 training, specify dtype explicitly.
Problem
Build a Dense Flax Module using @nn.compact with explicit initializers:
lecun_normal() for the kernel and zeros for the bias. Init and apply.
Inputs:
- seed: int (passed as float; cast to int).
- x: 1-D JAX array.
- features: int (cast to int).
Output: 1-D array of length features.