HF Weight Load (Kernel Transpose)
Why this matters
Most pre-trained checkpoints in the wild, especially those on the Hugging Face Hub, were saved from PyTorch. To reuse them in Flax you have to port the weights. The values are the same; the layout convention is not.
| API | Linear weight shape | Computation |
|---|---|---|
| PyTorch / HF | `(out, in)` | `y = x @ W.T + b` |
| Flax `nn.Dense` | `(in, out)` | `y = x @ kernel + bias` |
PyTorch stores the matrix that gets transposed at multiply time; Flax stores it pre-transposed. So:
```python
flax_kernel = pt_weight.T
```
The bias has the same `(out,)` layout in both, so no transpose is needed.
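A minimal NumPy sketch of the rule, with random arrays standing in for a real PyTorch weight (the sizes are arbitrary choices for illustration). It checks that `x @ W.T + b` and `x @ kernel + bias` agree once `kernel = W.T` and the bias is copied unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out = 4, 3, 5

x = rng.normal(size=(n, d_in))
pt_weight = rng.normal(size=(d_out, d_in))  # PyTorch/HF layout: (out, in)
pt_bias = rng.normal(size=(d_out,))         # (out,) in both frameworks

# What torch.nn.Linear computes: y = x @ W.T + b
y_pt = x @ pt_weight.T + pt_bias

# Port to the Flax layout: kernel is (in, out); bias is copied as-is
flax_kernel = pt_weight.T
flax_bias = pt_bias

# What flax.linen.Dense computes: y = x @ kernel + bias
y_flax = x @ flax_kernel + flax_bias

assert flax_kernel.shape == (d_in, d_out)
assert np.allclose(y_pt, y_flax)
```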
Get this wrong and the model will:
- Crash with a shape error when `in != out`, since the un-transposed kernel has shape `(out, in)` instead of `(in, out)`.
- Silently produce garbage when the shapes happen to line up, e.g. a square Dense layer with a `(D, D)` weight.
The square-layer case is the silent killer. Always sanity-check a port against a known input, e.g. compare the logits for one sample with those of the original PyTorch model.
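The silent failure is easy to reproduce. This sketch uses a small hand-picked, non-symmetric 3×3 weight (an arbitrary choice for illustration) and runs the one-sample comparison described above: both kernels have shape `(3, 3)` and run without error, but only the transposed one matches the reference output.

```python
import numpy as np

d = 3                                                    # square layer: in == out == D
pt_weight = np.arange(d * d, dtype=float).reshape(d, d)  # (out, in), not symmetric
pt_bias = np.ones(d)
x = np.array([[1.0, 2.0, 3.0]])                          # one known input sample

reference = x @ pt_weight.T + pt_bias                    # what the PyTorch layer outputs

good_kernel = pt_weight.T                                # correct port: (in, out)
bad_kernel = pt_weight                                   # forgot .T, shape is still (D, D)

good = x @ good_kernel + pt_bias
bad = x @ bad_kernel + pt_bias                           # no error is raised

print(good_kernel.shape == bad_kernel.shape)             # True: a shape check can't catch it
print(np.allclose(good, reference))                      # True
print(np.allclose(bad, reference))                       # False
```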
Beyond Linear: other layout gotchas
- Conv2D: PyTorch `(out, in, kH, kW)`, Flax `(kH, kW, in, out)`. This is a four-axis permutation, `transpose(2, 3, 1, 0)`, not a plain `.T` (which would also swap `kH` and `kW`).
- LayerNorm: same `(D,)` shape for `weight`/`bias`, no transpose (Flax names the weight `scale`).
- Embeddings: same `(vocab, D)` shape, no transpose.
- Multi-head attention: depends on the implementation. HF often packs Q, K, V into one weight; Flax tends to use three separate Denses. You may need to split the packed weight before transposing.
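A NumPy sketch of the two non-trivial cases. The Conv2D part uses a non-square kernel so that a plain `.T` visibly gets the spatial axes wrong. The attention part assumes a PyTorch-style packed projection of shape `(3 * D, D)` with Q, K, V stacked in that order along the output axis, as in `torch.nn.MultiheadAttention.in_proj_weight`; other HF models pack differently, so check the layout of the checkpoint you are actually porting.

```python
import numpy as np

# Conv2D: PyTorch (out, in, kH, kW) -> Flax (kH, kW, in, out)
pt_conv = np.zeros((8, 3, 5, 2))                  # out=8, in=3, kH=5, kW=2
flax_conv = np.transpose(pt_conv, (2, 3, 1, 0))
assert flax_conv.shape == (5, 2, 3, 8)
assert pt_conv.T.shape == (2, 5, 3, 8)            # plain .T puts kW before kH: wrong

# Packed attention projection (assumed layout): (3 * D, D), Q, K, V stacked on axis 0
d = 4
packed = np.random.default_rng(0).normal(size=(3 * d, d))
w_q, w_k, w_v = np.split(packed, 3, axis=0)       # each (D, D), still (out, in)
q_kernel, k_kernel, v_kernel = w_q.T, w_k.T, w_v.T  # then transpose each to (in, out)

# The packed projection and the three separate Dense kernels agree
x = np.random.default_rng(1).normal(size=(2, d))
qkv = x @ packed.T                                # (2, 3 * D)
assert np.allclose(qkv[:, :d], x @ q_kernel)
assert np.allclose(qkv[:, d:2 * d], x @ k_kernel)
assert np.allclose(qkv[:, 2 * d:], x @ v_kernel)
```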
How real “load HF into Flax” libraries do it
Tools like `flax_to_pt` or HF's own `convert_*_pt_to_jax.py` scripts walk the PyTorch `state_dict` and apply per-layer rules:
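```python
import torch

pt_state = torch.load("model.pt", map_location="cpu")  # a plain state_dict
flax_params = {}
for pt_key, pt_tensor in pt_state.items():
    # remap() and is_linear() are model-specific helpers, not shown here
    flax_key = remap(pt_key)  # "encoder.layer.0.weight" -> "Encoder/Dense_0/kernel"
    array = pt_tensor.detach().cpu().numpy()
    if pt_key.endswith(".weight") and is_linear(pt_key):
        flax_params[flax_key] = array.T  # (out, in) -> (in, out)
    else:
        flax_params[flax_key] = array    # biases etc.: copied without transposing
```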
The transpose-or-not rules are the heart of the converter. This problem isolates that single rule.
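The transpose-or-not decision can be sketched as one per-tensor rule. The function below, `convert_tensor`, is hypothetical: it dispatches on the PyTorch key suffix and the tensor's rank, and its name patterns (e.g. treating any 2-D weight whose key contains `embed` as an embedding) are assumptions that a real converter would replace with model-specific rules. It renames the leaf to the Flax parameter name (`kernel`, `scale`, `embedding`, `bias`) but leaves out the module-path remapping.

```python
import numpy as np

def convert_tensor(pt_key, arr):
    """Hypothetical rule: map one PyTorch tensor to (flax_leaf_name, array).

    The key patterns are illustrative assumptions, not a real library's rules.
    Remapping the module path (e.g. "encoder.dense" -> "Encoder/Dense_0") is omitted.
    """
    leaf = pt_key.rsplit(".", 1)[-1]
    if leaf == "bias":
        return "bias", arr                                # (out,) in both frameworks
    if leaf != "weight":
        return leaf, arr                                  # anything else: copy unchanged
    if arr.ndim == 4:
        return "kernel", np.transpose(arr, (2, 3, 1, 0))  # conv: (out, in, kH, kW) -> (kH, kW, in, out)
    if arr.ndim == 1:
        return "scale", arr                               # LayerNorm weight -> Flax "scale"
    if arr.ndim == 2 and "embed" in pt_key:
        return "embedding", arr                           # (vocab, D) in both frameworks
    if arr.ndim == 2:
        return "kernel", arr.T                            # linear: (out, in) -> (in, out)
    raise ValueError(f"no rule for {pt_key} with shape {arr.shape}")

fake_state = {
    "embeddings.word_embeddings.weight": np.zeros((100, 8)),
    "embeddings.LayerNorm.weight": np.ones((8,)),
    "encoder.dense.weight": np.zeros((16, 8)),
    "encoder.dense.bias": np.zeros((16,)),
    "stem.conv.weight": np.zeros((8, 3, 3, 3)),
}
for key, arr in fake_state.items():
    name, out = convert_tensor(key, arr)
    print(key, "->", name, out.shape)
# embeddings.word_embeddings.weight -> embedding (100, 8)
# embeddings.LayerNorm.weight -> scale (8,)
# encoder.dense.weight -> kernel (8, 16)
# encoder.dense.bias -> bias (16,)
# stem.conv.weight -> kernel (3, 3, 3, 8)
```

The rank checks run before the name check so that a 1-D LayerNorm weight living under an `embeddings.` prefix is not mistaken for an embedding table.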
Problem
You’re given:
- `seed`, `x`: used to init the Flax `nn.Dense(features)`. We do NOT use the random init values; we only need the shape of the param tree.
- `hf_kernel_flat`: 1-D, holding an HF-style weight matrix flattened from shape `(features, in_dim)`.
- `hf_bias_flat`: 1-D, holding the HF bias of shape `(features,)`.
Steps:
- Reshape `hf_kernel_flat` to `(features, in_dim)` (the HF convention).
- Transpose it to get the Flax kernel of shape `(in_dim, features)`.
- Build `new_params = {"kernel": flax_kernel, "bias": hf_bias}`, overriding the random init.
- Run `model.apply({"params": new_params}, x)`.
- Return the flattened `apply` result.
Inputs:
- `seed`: float (cast to int).
- `x`: 2-D, shape `(N, in_dim)`.
- `hf_kernel_flat`: 1-D, shape `(features * in_dim,)`.
- `hf_bias_flat`: 1-D, shape `(features,)`.
- `features`: float (cast to int).
Output: 1-D, shape `(N * features,)`: the flattened `apply` result.
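A minimal sketch of the steps. To keep it runnable without Flax installed, the `nn.Dense` forward pass is swapped for the equivalent NumPy expression `x @ kernel + bias`; the Flax calls it stands in for are shown as comments. `hf_dense_apply` is a hypothetical name, and the tiny input in the usage lines is made up for illustration.

```python
import numpy as np

def hf_dense_apply(seed, x, hf_kernel_flat, hf_bias_flat, features):
    """NumPy stand-in for the Flax version: x @ kernel + bias replaces nn.Dense."""
    del seed                                   # only needed to seed the Flax init
    features = int(features)
    x = np.asarray(x, dtype=np.float32)
    in_dim = x.shape[1]

    # Step 1: reshape to the HF convention (features, in_dim)
    hf_kernel = np.asarray(hf_kernel_flat, dtype=np.float32).reshape(features, in_dim)
    # Step 2: transpose to the Flax convention (in_dim, features)
    flax_kernel = hf_kernel.T
    hf_bias = np.asarray(hf_bias_flat, dtype=np.float32).reshape(features)
    # Step 3: new param tree that overrides the random init
    new_params = {"kernel": flax_kernel, "bias": hf_bias}
    # Step 4: in Flax this would be
    #   model = nn.Dense(features)
    #   model.init(jax.random.PRNGKey(int(seed)), x)  # random values are ignored
    #   y = model.apply({"params": new_params}, x)
    y = x @ new_params["kernel"] + new_params["bias"]
    # Step 5: flatten to (N * features,)
    return y.reshape(-1)

# HF weight [[1, 0], [0, 1], [1, 1]] (features=3, in_dim=2), bias [0, 0, 10]
out = hf_dense_apply(0.0, [[1.0, 2.0], [3.0, 4.0]], [1, 0, 0, 1, 1, 1], [0, 0, 10], 3.0)
print(out.tolist())  # [1.0, 2.0, 13.0, 3.0, 4.0, 17.0]
```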