
NNX Implement Positional Embed

Why this matters

Transformers process tokens in parallel, with no recurrence to encode order. Positional embeddings are the workaround: a learned (or sinusoidal) signal added to the token embeddings that tells the model which position each token occupies. BERT and GPT-2 use learned positional embeddings: a (max_T, d_model) matrix initialized to small normal values and sliced to the actual sequence length on each forward pass.

Modern LLaMA-style models typically use RoPE instead (problem 39 in this track). Learned positional embeddings remain in BERT, GPT-2, ViT and many other vision transformers, and they are the simplest starting point for the concept.

API

A single trainable (max_T, d_model) matrix:

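import jax
from flax import nnx

class MyPositionalEmbed(nnx.Module):
    def __init__(self, max_T: int, d_model: int, rngs: nnx.Rngs):
        key = rngs.params()
        # One learned row per position, initialized at a small scale.
        self.pos_embed = nnx.Param(
            jax.random.normal(key, (max_T, d_model)) * 0.02
        )

    def __call__(self, T: int) -> jax.Array:
        # Rows 0..T-1 are the embeddings for positions 0..T-1.
        return self.pos_embed[:T]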

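The forward pass slices the table to the actual sequence length T. This is the GPT-2/BERT pattern: allocate max_T rows up front (1024 for GPT-2, 512 for BERT) and slice down at use time. The table has no rows for positions at or beyond max_T, so longer sequences are unsupported by design. Note that the slice itself does not raise: like NumPy, JAX clamps [:T] to the table's length, so T > max_T silently returns only max_T rows, and the failure surfaces later as a shape mismatch when positions are added to the token embeddings. Check T <= max_T explicitly if you want a clear error.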
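The clamping behavior is easy to check directly. This is a minimal sketch on a plain random table; the helper name positions is hypothetical, not part of Flax or the problem spec, and simply adds the explicit length check suggested above.

```python
import jax

max_T, d_model = 64, 8
# Stand-in for the learned table (random here; trained in a real model).
table = jax.random.normal(jax.random.key(0), (max_T, d_model)) * 0.02

# A plain slice past the end is clamped, NumPy-style: no error is raised.
too_long = table[:100]
print(too_long.shape)  # (64, 8), not (100, 8)


def positions(table: jax.Array, T: int) -> jax.Array:
    """Hypothetical helper: return the first T rows, failing loudly if T > max_T."""
    max_T = table.shape[0]
    if T > max_T:
        raise ValueError(f"sequence length {T} exceeds max_T={max_T}")
    return table[:T]


print(positions(table, 4).shape)  # (4, 8)
```

Because T is a static Python int here, the check runs at trace time and costs nothing inside a jitted model.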

The 0.02 init scale

BERT and GPT-2 initialize positional embeddings (and token embeddings) with normal(stddev=0.02). The small scale is deliberate: positional embeddings are ADDED to token embeddings, so they should not dominate before the model has learned what to do with them. With unit-variance positions and 0.02-scale tokens, the position signal would be about 50 times larger than the token identity signal at step 0, and the model would first have to shrink it back down.
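A quick way to see the factor of 50 is to compare per-row norms at init. This sketch assumes token embeddings are also drawn at stddev 0.02, as stated above; the sizes (T=64, d_model=512) are arbitrary choices made so the averages are stable.

```python
import jax
import jax.numpy as jnp

T, d_model = 64, 512
k_tok, k_pos = jax.random.split(jax.random.key(0))

# Token-embedding rows at the 0.02 init scale (as looked up for T tokens).
tok = jax.random.normal(k_tok, (T, d_model)) * 0.02
pos_unit = jax.random.normal(k_pos, (T, d_model))  # unit-variance init
pos_small = pos_unit * 0.02                        # BERT/GPT-2-style init


def ratio(pos: jax.Array, tok: jax.Array) -> float:
    # Mean per-row L2 norm of the position signal relative to the token signal.
    return float(jnp.linalg.norm(pos, axis=-1).mean()
                 / jnp.linalg.norm(tok, axis=-1).mean())


print(ratio(pos_small, tok))  # roughly 1: position and token signals comparable
print(ratio(pos_unit, tok))   # roughly 50: position swamps token identity
```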

Worked example

rngs = nnx.Rngs(0)
model = MyPositionalEmbed(max_T=64, d_model=8, rngs=rngs)
print(model.pos_embed.value.shape)         # (64, 8)
out = model(4)                              # (4, 8) — first 4 positions

Linen contrast

# Linen, for contrast.
import flax.linen as nn

class MyPositionalEmbed(nn.Module):
    max_T: int
    d_model: int

    @nn.compact
    def __call__(self, T: int):
        # Declared here; actually created by init and passed back in on apply.
        E = self.param("pos_embed", nn.initializers.normal(0.02),
                       (self.max_T, self.d_model))
        return E[:T]

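Same math, same init distribution. The difference is state handling: Linen creates parameters with model.init(...) and passes them back in with model.apply(...), while NNX stores them on the module object itself, so you just call model(T).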

Combining with token embeddings

In a real model:

tok = token_embed(ids)                     # (T, d_model)
pos = pos_embed(T)                          # (T, d_model)
h = tok + pos                               # (T, d_model)

The + is the whole story: token identity and position are superimposed in the same d_model-dimensional vector, and training teaches the model to read both back out.

Common pitfalls

  • Forgetting * 0.02. A standard-normal table is about 50 times larger than 0.02-scale token embeddings, so position dominates the sum and convergence can be much slower.
  • Sizing the table from the training T. If you allocate the table to the training-time T instead of a configured max_T (64 here), longer eval-time sequences do not raise at the slice: [:T] silently returns too few rows, and adding them to the token embeddings then fails with a shape mismatch.
  • T arriving as a float. Slice bounds must be integers, so cast with int(T) before slicing.
  • Indexing by ids (pos_embed[ids]). Gathering rows by id is the token-embedding pattern. Positional embeddings just slice [:T], because the rows ARE the positions.
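The last two pitfalls can be seen on a plain array. In this sketch the token ids are made-up values, used only to show that gathering by id selects unrelated rows, while [:T] on an int-cast T is the same as gathering positions 0..T-1.

```python
import jax
import jax.numpy as jnp

table = jax.random.normal(jax.random.key(0), (64, 8)) * 0.02

T = int(4.0)  # cast first: T may arrive as a float from the harness

# Rows ARE positions: slicing [:T] selects exactly the rows that gathering
# the position indices 0..T-1 would select.
sliced = table[:T]
gathered = table[jnp.arange(T)]
print(sliced.shape, bool(jnp.array_equal(sliced, gathered)))  # (4, 8) True

# Gathering by (hypothetical) token ids picks rows by vocabulary id instead,
# which is a token-embedding lookup, not a position lookup.
ids = jnp.array([17, 3, 3, 42])
wrong = table[ids]
print(bool(jnp.array_equal(wrong, sliced)))  # False
```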

Problem

Write positional_embed_forward(seed, T, d_model):

  1. Define MyPositionalEmbed(nnx.Module) with one nnx.Param: self.pos_embed shape (max_T, d_model) initialized as jax.random.normal(key, shape) * 0.02. Use max_T = 64.
  2. __call__(self, T): return self.pos_embed[:T].
  3. Cast T and d_model from float to int.
  4. Build rngs = nnx.Rngs(int(seed)), instantiate MyPositionalEmbed(max_T=64, d_model=int(d_model), rngs=rngs), and return model(int(T)).reshape(-1).

Inputs:

  • seed: int (passed as float).
  • T: int (passed as float — cast).
  • d_model: int (passed as float — cast).

Output: 1-D array of length T * d_model.

Hints

flax nnx positional-embedding reimplementation
