medium primitives

Learned Position Embedding

Why this matters

Sinusoidal encodings are clever but inflexible: every model that uses them inherits the fixed frequency schedule from Vaswani et al. (2017). Learned position embeddings replace that schedule with a free (max_T, D) parameter matrix, training one D-dimensional vector per absolute position. GPT-1, GPT-2, BERT, and many other early Transformers use learned positions.

The trade-offs vs sinusoidal:

Property                  Sinusoidal                         Learned
Trainable parameters      0                                  max_T·D
Positions beyond max_T    Computable; extrapolation often weak   None exist
Adapts to the dataset     No (fixed formula)                 Yes (trained)
Initial values            Fixed frequency schedule           Random (e.g. normal)

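Learned embeddings are a good fit when both training and evaluation sequence lengths stay within max_T. Sinusoidal encodings can at least be evaluated at longer lengths, which makes them the usual baseline for length extrapolation. Many modern long-context models use relative schemes such as RoPE or ALiBi instead of either.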
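A small NumPy sketch makes the parameter-count and beyond-max_T trade-offs concrete. The helper name sinusoidal and the sizes max_T = 32, D = 16 are illustrative assumptions, and the "learned" matrix is just random values standing in for trained ones.

```python
import numpy as np


def sinusoidal(T, D):
    """Vaswani-style sinusoidal table of shape (T, D); assumes D is even."""
    pos = np.arange(T)[:, None]                 # (T, 1) positions
    i = np.arange(D // 2)[None, :]              # (1, D/2) frequency indices
    angles = pos / (10000.0 ** (2 * i / D))     # (T, D/2)
    table = np.zeros((T, D))
    table[:, 0::2] = np.sin(angles)             # even dims: sin
    table[:, 1::2] = np.cos(angles)             # odd dims: cos
    return table


max_T, D = 32, 16

# Learned: a fixed-size matrix of free parameters (random stand-in here).
rng = np.random.default_rng(0)
learned = rng.normal(0.0, 0.02, size=(max_T, D))
print(learned.size)               # 512 trainable values = max_T * D

# Sinusoidal: no parameters, and any length can be computed on demand.
print(sinusoidal(40, D).shape)    # (40, 16): rows exist past max_T
print(learned[:40].shape)         # (32, 16): no rows exist past max_T

# For scale: GPT-2 small's position table is 1024 positions x 768 dims.
print(1024 * 768)                 # 786432
```

Because sinusoidal rows come from a formula, a longer table agrees with a shorter one on their shared prefix; the learned matrix simply stops at max_T.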

The Flax pattern

An nn.Module that declares a single parameter via self.param:

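import flax.linen as nn

class LearnedPos(nn.Module):
    max_T: int    # longest sequence the table supports
    d_model: int  # width of each position vector

    @nn.compact
    def __call__(self, T):
        # Created by the initializer during init; read from params during apply.
        pos_embed = self.param(
            "pos_embed",
            nn.initializers.normal(stddev=0.02),
            (self.max_T, self.d_model),
        )
        # First T rows: shape (T, d_model).
        return pos_embed[:T]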

Three pieces:

  1. Class attributes max_T and d_model: Flax dataclass fields, set at construction.
  2. self.param(name, init_fn, shape): during model.init, Flax calls init_fn(rng, shape) to create the parameter; during model.apply, the same call returns the stored (for example, trained) value from the params you pass in.
  3. pos_embed[:T]: slices the first T rows, giving a (T, d_model) array for the current sequence.

Why normal(stddev=0.02)?

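0.02 is the common default for embedding-style parameters in BERT and GPT-2 implementations (for example, initializer_range = 0.02 in their Hugging Face configs). It keeps position vectors small at initialization, so they perturb the token embeddings without dominating them, while still giving every position a distinct random starting vector. Note that nn.initializers.normal samples from an ordinary, untruncated normal distribution with mean 0; for a truncated normal (as BERT uses), use nn.initializers.truncated_normal.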

Other reasonable choices: truncated_normal(stddev=0.02) (BERT-style), xavier_uniform(), lecun_normal(), or zeros (every position starts identical and is differentiated only by training).
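The difference between these initializers is easy to check numerically. This sketch uses jax.nn.initializers (which flax.linen re-exports as nn.initializers); the (1024, 768) shape is an illustrative (max_T, d_model), and reusing one PRNG key for every initializer is a simplification for the demo.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(0)
shape = (1024, 768)   # illustrative (max_T, d_model)

plain = initializers.normal(stddev=0.02)(key, shape)
trunc = initializers.truncated_normal(stddev=0.02)(key, shape)  # cut at +/- 2 std by default
zeros = initializers.zeros(key, shape)

# normal() is untruncated: with ~786k samples, some land beyond 2 std (0.04).
print(float(jnp.abs(plain).max()) > 0.04)    # True
print(round(float(plain.std()), 3))          # 0.02

# truncated_normal() stays inside (-0.04, 0.04); cutting the tails shrinks the std.
print(float(jnp.abs(trunc).max()))           # slightly below 0.04
print(round(float(trunc.std()), 3))          # about 0.018

# zeros: every position vector starts identical.
print(bool(jnp.all(zeros == 0.0)))           # True
```

In the module, replacing nn.initializers.normal(stddev=0.02) with nn.initializers.truncated_normal(stddev=0.02) in the self.param call is the only change needed for a BERT-style truncated init.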

Common pitfalls

  • Using nn.Embed for positions: it works, but it performs a gather over arange(T) that returns exactly the rows of the contiguous slice pos_embed[:T], so the slice is simpler and skips the gather.
  • pos_embed[T] vs pos_embed[:T]: the former is one row at index T; we want the first T rows.
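  • max_T too small at eval: if an eval sequence is longer than max_T, pos_embed[:T] does not raise. NumPy-style slicing silently returns only max_T rows, and the problem surfaces later, for example as a broadcasting error when adding to the (T, d_model) token embeddings. Sinusoidal encodings can be computed for any position; a learned table has no vectors past max_T.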
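  • Forgetting @nn.compact: self.param may only be called in setup or in a method decorated with @nn.compact; otherwise Flax raises an error. Without the decorator you must declare self.pos_embed in setup instead. The compact decorator lets you declare params inline in __call__.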

Problem

Implement learned_pos_embed_forward(seed, T, d_model):

  1. Cast T, d_model to int. Use max_T = 32.
  2. Define a nn.Module (compact) that declares pos_embed of shape (max_T, d_model) initialised with normal(stddev=0.02).
  3. The module’s forward takes T and returns pos_embed[:T].
  4. Call init with jax.random.PRNGKey(seed) and T, then apply the resulting params with T.
  5. Return the output flattened to 1-D.

Inputs:

  • seed: int.
  • T: int — sequence length, ≤ max_T.
  • d_model: int.

Output: 1-D array of length T * d_model.
