Learned Position Embedding
Why this matters
Sinusoidal encodings are clever but inflexible: every model that uses them
inherits the fixed frequency schedule from Vaswani et al. Learned position
embeddings replace that schedule with a free (max_T, D) parameter matrix,
training one vector per absolute position. GPT-1, GPT-2, BERT, and many other
early Transformers use learned positions.
The trade-offs versus sinusoidal encodings:
| Property | Sinusoidal | Learned |
|---|---|---|
| Trainable parameters | 0 | max_T·D |
| Produces encodings for positions ≥ max_T | Yes (computed on demand) | No (no row exists) |
| Adapts to the training data | No (fixed formula) | Yes |
| Initialisation | Known closed-form pattern | Random |
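To put the max_T·D entry in scale, here are two well-known configurations: GPT-2 small (1024 positions, width 768) and BERT-base (512 positions, width 768):

$$
\underbrace{1024 \times 768}_{\text{GPT-2 small}} = 786{,}432,
\qquad
\underbrace{512 \times 768}_{\text{BERT-base}} = 393{,}216 .
$$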
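Learned embeddings work well when training covers sequences up to max_T
and inference never exceeds it. Sinusoidal encodings at least produce values
for longer inputs, so they are the usual choice when length extrapolation
matters. Many modern long-context models sidestep both with relative schemes
such as RoPE or ALiBi.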
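The length asymmetry in the table is easy to see in code. In this sketch, `sinusoidal_table` is a hypothetical helper that implements the Vaswani et al. formula (sin on even dimensions, cos on odd ones, base 10000, even `d_model` assumed). The learned table is a plain array with a fixed number of rows, standing in for a trained `pos_embed`.

```python
import jax
import jax.numpy as jnp


def sinusoidal_table(T: int, d_model: int) -> jnp.ndarray:
    """Hypothetical helper: fixed sinusoidal encodings for positions 0..T-1."""
    pos = jnp.arange(T)[:, None]                      # (T, 1)
    i = jnp.arange(0, d_model, 2)[None, :]            # (1, d_model/2)
    angles = pos / (10000.0 ** (i / d_model))         # (T, d_model/2)
    table = jnp.zeros((T, d_model))
    table = table.at[:, 0::2].set(jnp.sin(angles))   # even dims
    table = table.at[:, 1::2].set(jnp.cos(angles))   # odd dims
    return table


max_T, D = 32, 16
learned = 0.02 * jax.random.normal(jax.random.PRNGKey(0), (max_T, D))

# Sinusoidal: any length can be generated on demand, including T > max_T.
print(sinusoidal_table(64, D).shape)   # (64, 16)

# Learned: only max_T rows exist; slicing past the end silently stops at max_T.
print(learned[:64].shape)              # (32, 16)
```

The sinusoidal rows for positions 0..31 are identical whether you generate 32 or 64 of them. This is what makes the sinusoidal table "computable on demand" while the learned one is not.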
The Flax pattern
An nn.Module that declares a single parameter via self.param:
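```python
import flax.linen as nn


class LearnedPos(nn.Module):
    max_T: int    # largest sequence length the table supports
    d_model: int  # embedding width

    @nn.compact
    def __call__(self, T):
        # One trainable d_model-dim vector per absolute position 0..max_T-1.
        pos_embed = self.param(
            "pos_embed",
            nn.initializers.normal(stddev=0.02),
            (self.max_T, self.d_model),
        )
        # Keep only the rows for the current sequence length.
        return pos_embed[:T]
```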
Three pieces:
- Class attributes `max_T` and `d_model`: Flax dataclass fields, set at construction.
- `self.param(name, init_fn, shape)`: declares a learnable parameter. During `init`, Flax calls `init_fn(rng, shape)` to create it; during `apply`, it returns the array from the parameters you pass in (the trained values, once training has run).
- `pos_embed[:T]`: slices the first `T` rows for the current sequence.
Why normal(stddev=0.02)?
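0.02 is a common default for embedding-style parameters in BERT- and
GPT-style models. It is small enough that initial magnitudes before the first
LayerNorm stay sane, and large enough that every position starts with a
distinct vector. nn.initializers.normal(stddev) samples from an ordinary
(untruncated) normal distribution with mean 0; a truncated normal is a
separate initializer.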
Other reasonable choices: xavier_uniform(), lecun_normal(), or zeros (every
position starts identical and is learned from scratch, which can converge more
slowly).
Common pitfalls
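- Using `nn.Embed` for positions: it works, but wastes a gather. The input is just `arange(T)`, so a direct slice returns the same rows more cheaply.
- `pos_embed[T]` vs `pos_embed[:T]`: the former is a single row at index `T`; we want the first `T` rows.
- `max_T` too small at eval: if the eval sequence is longer than `max_T`, there are no rows for the extra positions. `pos_embed[:T]` silently stops at `max_T` rows, so the problem only surfaces later as a shape mismatch. Sinusoidal encodings can be computed for any position; a learned table cannot.
- Forgetting `@nn.compact`: without it, calling `self.param` inside `__call__` fails, and Flax instead expects an explicit `setup` method that declares `self.pos_embed`. The compact decorator lets you declare params inline in `__call__`.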
Problem
Implement learned_pos_embed_forward(seed, T, d_model):
- Cast `T` and `d_model` to `int`. Use `max_T = 32`.
- Define an `nn.Module` (compact) that declares `pos_embed` of shape `(max_T, d_model)`, initialised with `normal(stddev=0.02)`.
- The module's forward takes `T` and returns `pos_embed[:T]`.
- Init with `jax.random.PRNGKey(seed)` and apply.
- Return the output flattened to 1-D.
Inputs:
- `seed`: int.
- `T`: int, the sequence length, ≤ `max_T`.
- `d_model`: int.
Output: 1-D array of length T * d_model.