NNX Implement Positional Embed
Why this matters
Transformers process tokens in parallel — there’s no recurrence to
encode order. Positional embeddings are the workaround: a learned (or
sinusoidal) signal added to the token embeddings to tell the model
which position each token sits at. BERT and GPT-2 use learned
positional embeddings — a (max_T, d_model) matrix initialized to
small normal values, sliced to the actual sequence length on each
forward.
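Plain self-attention makes the "no recurrence to encode order" point concrete. It is permutation-equivariant: shuffling the input tokens only shuffles the outputs. The NumPy sketch below assumes a single head with identity Q/K/V projections and no causal mask. That is a simplification, not how BERT or GPT-2 are built. Adding a per-position table breaks the symmetry.

```python
import numpy as np

def self_attention(x):
    """Single-head softmax attention; identity Q/K/V projections are a simplifying assumption."""
    scores = x @ x.T / np.sqrt(x.shape[-1])        # (T, T) token-to-token similarity
    scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x                              # (T, d) mixed token vectors

rng = np.random.default_rng(0)
T, d_model = 5, 8
x = rng.normal(size=(T, d_model))     # stand-in token embeddings
perm = np.roll(np.arange(T), 1)       # fixed non-identity reordering: [4, 0, 1, 2, 3]

# Without positions: reordering the input just reorders the output.
no_pos_same = np.allclose(self_attention(x)[perm], self_attention(x[perm]))

# With a per-position table added, the same reordering changes the result.
pos = rng.normal(size=(T, d_model)) * 0.02
with_pos_same = np.allclose(self_attention(x + pos)[perm], self_attention(x[perm] + pos))

print(no_pos_same, with_pos_same)  # True False
```

A causal mask, as in GPT-2, also breaks the symmetry partially. BERT-style bidirectional attention relies entirely on the positional signal.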
Modern LLaMA-style models tend to use RoPE instead (problem 39 in this track). But learned positional embeddings remain in BERT, GPT-2, ViT, and many vision transformers — and they’re the simplest starting point for the concept.
API
A single trainable (max_T, d_model) matrix:
import jax
from flax import nnx

class MyPositionalEmbed(nnx.Module):
    def __init__(self, max_T, d_model, rngs):
        key = rngs.params()  # draw a fresh key from the "params" stream
        self.pos_embed = nnx.Param(
            jax.random.normal(key, (max_T, d_model)) * 0.02
        )

    def __call__(self, T):
        return self.pos_embed[:T]  # first T rows = positions 0..T-1
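The forward slices the table to the actual sequence length T. This
is the GPT-2/BERT pattern: allocate max_T up front (1024 in GPT-2,
512 in BERT) and slice down at use time. Sequences longer than max_T
are meant to fail, because the table has no rows, and therefore no
trained values, for those positions. Note that the slice itself does
not raise. Like NumPy, JAX clamps an out-of-range slice bound, so
pos_embed[:T] with T > max_T quietly returns all max_T rows. The error
only surfaces as a shape mismatch when that slice is added to the
(T, d_model) token embeddings.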
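The sketch below shows the out-of-range case, using a bare JAX array in place of the Param. The oversized slice comes back with max_T rows instead of raising. The hypothetical `checked_positions` helper is one way to make the failure explicit; it is not part of the problem's API.

```python
import jax
import jax.numpy as jnp

max_T, d_model = 64, 8
table = jax.random.normal(jax.random.PRNGKey(0), (max_T, d_model)) * 0.02

# Slice bounds past the end are clamped (NumPy semantics), so this does not raise.
print(table[:100].shape)  # (64, 8)

def checked_positions(table, T):
    """Hypothetical helper: slice the first T rows, refusing lengths beyond max_T."""
    if T > table.shape[0]:
        raise ValueError(f"sequence length {T} exceeds max_T={table.shape[0]}")
    return table[:T]

print(checked_positions(table, 4).shape)  # (4, 8)
```

Without such a check, a 100-token sequence would fail only later, when its (100, 8) token embeddings meet the (64, 8) slice in the addition.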
The 0.02 init scale
BERT and GPT-2 initialize positional embeddings (and token embeddings)
with normal(stddev=0.02). The small scale is deliberate: positional
embeddings get ADDED to token embeddings, so they shouldn’t dominate
until the model learns what to do with them. With unit-variance
init, the position vectors would be about 50 times larger than
0.02-scale token embeddings at step 0. That swamps the token identity
signal, and training would first have to shrink them back down.
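The NumPy sketch below puts a number on "swamp". It compares the size of a positional table with a token-embedding table drawn at stddev 0.02, the GPT-2/BERT scale mentioned above. The 1024 x 768 shapes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model = 1024, 768  # illustrative sizes, not taken from any specific config

tok = rng.normal(size=(T, d_model)) * 0.02        # token embeddings at the 0.02 scale
pos_small = rng.normal(size=(T, d_model)) * 0.02  # positional table at the 0.02 scale
pos_unit = rng.normal(size=(T, d_model))          # positional table at unit variance

def ratio(pos, tok):
    # Size of the position signal relative to the token signal it is added to.
    return np.linalg.norm(pos) / np.linalg.norm(tok)

print(ratio(pos_small, tok))  # roughly 1
print(ratio(pos_unit, tok))   # roughly 50
```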
Worked example
rngs = nnx.Rngs(0)
model = MyPositionalEmbed(max_T=64, d_model=8, rngs=rngs)
print(model.pos_embed.value.shape) # (64, 8)
out = model(4) # (4, 8) — first 4 positions
Linen contrast
# Linen, for contrast.
import flax.linen as nn

class MyPositionalEmbed(nn.Module):
    max_T: int
    d_model: int

    @nn.compact
    def __call__(self, T):
        E = self.param("pos_embed", nn.initializers.normal(0.02),
                       (self.max_T, self.d_model))
        return E[:T]
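Same math, same init. The difference is bookkeeping. A Linen module is stateless: you create its parameters with model.init(key, T) and run it with model.apply(variables, T). An NNX module holds its Param as an attribute and is called directly.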
Combining with token embeddings
In a real model:
T = ids.shape[0]
tok = token_embed(ids)  # (T, d_model): one row per token id
pos = pos_embed(T)      # (T, d_model): one row per position
h = tok + pos           # (T, d_model)
The + is the whole story — token identity and position are
superimposed in the same vector. Surprisingly, the model learns to
disentangle them.
Common pitfalls
- Forgetting * 0.02. A standard normal init makes the positional embeddings dominate the token embeddings, and convergence is much slower.
- Sizing the table from the training-time T instead of max_T. If the table only has as many rows as the training sequences, longer eval-time sequences have no positions to look up. The slice silently comes back short, and the addition then fails. Allocate from a configured max_T (we use 64 here).
- T arriving as a float. Cast it to int before slicing, because a float slice bound raises an error.
- Indexing pos_embed[ids] as you would a token table. Gathering by id is for token embeddings. Positional embeddings just slice [:T], because the rows ARE the positions.
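The sketch below shows the last two pitfalls in miniature, with a NumPy array standing in for the JAX positional table. Slicing `[:T]` is the same as gathering positions 0..T-1. Gathering by token ids picks arbitrary rows instead. A float length has to be cast before it can bound a slice.

```python
import numpy as np

table = np.random.default_rng(0).normal(size=(64, 8)) * 0.02  # stand-in (max_T, d_model) table
T = 4

# Positional lookup: the first T rows ARE positions 0..T-1, so a slice is enough.
same = np.array_equal(table[:T], table[np.arange(T)])
print(same)  # True

# Token-style gather: rows are picked by id, not by position.
ids = np.array([7, 7, 3, 0])
gathered = table[ids]
print(np.array_equal(gathered, table[:T]))  # False: rows 7, 7, 3, 0 instead of 0, 1, 2, 3

# A length passed as a float must be cast before slicing.
T_float = 4.0
try:
    table[:T_float]
except (TypeError, IndexError) as err:
    print(type(err).__name__, err)
print(table[:int(T_float)].shape)  # (4, 8)
```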
Problem
Write positional_embed_forward(seed, T, d_model):
- Define MyPositionalEmbed(nnx.Module) with one nnx.Param: self.pos_embed of shape (max_T, d_model), initialized as jax.random.normal(key, shape) * 0.02. Use max_T = 64.
- __call__(self, T): return self.pos_embed[:T].
- Cast T and d_model from float to int.
- Build with nnx.Rngs(int(seed)), instantiate with (max_T=64, d_model=int(d_model)), and return model(int(T)).reshape(-1).
Inputs:
- seed: int (passed as float).
- T: int (passed as float; cast).
- d_model: int (passed as float; cast).
Output: 1-D array of length T * d_model.