NNX Bridge: Load HuggingFace Weights into NNX
Why this matters
Most pretrained model weights you’ll encounter — Llama, GPT-2, BERT,
Mistral, Qwen, you name it — ship as PyTorch state dicts. The
HuggingFace transformers library publishes them with PyTorch
layout conventions even when wrapped behind framework-agnostic APIs.
To use those weights in JAX/Flax/nnx, you have to translate the
layout. The single most important difference: PyTorch
nn.Linear.weight has shape (out_features, in_features), while
Flax (Linen and nnx alike) uses (in_features, out_features). One
transpose. Get this right and the rest is mechanical.
The convention difference, in one diagram
PyTorch:
y = x @ W.T + b # W is (out, in), so W.T is (in, out)
weight.shape = (out, in)
Flax:
y = x @ kernel + b # kernel is (in, out)
kernel.shape = (in, out)
The math is the same; the stored layout differs. To port:
flax_kernel = pt_weight.T # transpose (out, in) → (in, out)
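A quick numerical check of the diagram. This is a sketch that uses NumPy as a stand-in for jax.numpy (the operations involved behave the same); the shapes and random values are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features = 3, 4

# PyTorch layout: weight is (out, in).
pt_weight = rng.normal(size=(out_features, in_features))
bias = rng.normal(size=(out_features,))
x = rng.normal(size=(in_features,))

# PyTorch-style forward: y = x @ W.T + b
y_torch = x @ pt_weight.T + bias

# Flax-style forward with the ported kernel: y = x @ kernel + b
flax_kernel = pt_weight.T  # (in, out)
y_flax = x @ flax_kernel + bias

print(flax_kernel.shape)             # (3, 4)
print(np.allclose(y_torch, y_flax))  # True
```

Both conventions compute the same y; only the stored matrix differs.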
The full port
from flax import nnx
import jax.numpy as jnp

def load_hf_weights(seed, x, hf_kernel_flat, hf_bias_flat, features):
    F = int(features)               # features arrives as a float
    in_features = int(x.shape[-1])

    # Build a fresh nnx Linear (init values will be overwritten).
    rngs = nnx.Rngs(int(seed))
    model = nnx.Linear(in_features=in_features, out_features=F, rngs=rngs)

    # HF ships the kernel flat in PyTorch layout (out, in), row-major.
    hf_kernel = jnp.asarray(hf_kernel_flat).reshape(F, in_features)
    nnx_kernel = hf_kernel.T        # (in, out) for Flax

    # Overwrite the values held by the existing Param variables.
    model.kernel.value = nnx_kernel
    model.bias.value = jnp.asarray(hf_bias_flat)  # (out,) in both layouts
    return model(x)
Why the inputs are flat: in production the HF tensor is already
2-D, but the test harness only handles flat arrays. The reshape
only undoes that flattening; the transpose that follows is exactly
what you'd do to a real state_dict entry.
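To see that the flat path and the 2-D path produce the same kernel, here is a small sketch in NumPy (standing in for jax.numpy). The 2 x 3 weight is made up for illustration; flattening with ravel() mirrors the row-major layout the harness is described as using.

```python
import numpy as np

F, in_features = 2, 3

# A "real" state_dict entry: already 2-D, PyTorch layout (out, in).
pt_weight = np.arange(F * in_features, dtype=np.float32).reshape(F, in_features)

# What the harness passes instead: the same tensor flattened row-major.
hf_kernel_flat = pt_weight.ravel()

# Flat path: restore (out, in) first, then transpose to (in, out).
kernel_from_flat = hf_kernel_flat.reshape(F, in_features).T

# 2-D path: transpose directly.
kernel_from_2d = pt_weight.T

print(np.array_equal(kernel_from_flat, kernel_from_2d))  # True
```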
Reading test cases
Test 1: HF kernel is [[1,0,0],[0,1,0],[0,0,1],[0,0,0]] (F=4, in=3). The first three rows are an identity selector; the fourth
is zero. With a zero bias, model([1, 2, 3]) = [1, 2, 3, 0].
Test 2: HF kernel is all ones, F=2, in=3. Each row sums the
input to 6; adding the bias [1, -1] gives model([1, 2, 3]) = [6+1, 6-1] = [7, 5].
Test 3: HF kernel is a permutation cycling (0, 1, 2) → (1, 2, 0),
with a constant bias of 0.5 added to every output.
Test 4 (hidden): HF kernel encodes a 45° rotation, applied to (1, 0); the output is the unit vector at 45°, (cos 45°, sin 45°) ≈ (0.707, 0.707).
Each case is interpretable, so you can verify the transpose is right. Forget the transpose and the results go wrong: the non-square cases (Tests 1 and 2) fail with a shape mismatch, while on the square cases the permutation runs in the opposite direction and the rotation turns by -45° instead of +45°.
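A NumPy reference for the forward pass lets you check the expected outputs by hand. This is a sketch, not the graded solution: reference_forward is a hypothetical helper that reproduces the nnx.Linear math (x @ kernel + bias). Tests 1 and 4 assume a zero bias, the 45° rotation matrix is my own construction (the hidden test's exact values are not shown), and Test 3 is omitted because the text does not pin down its exact matrix.

```python
import numpy as np

def reference_forward(x, hf_kernel_flat, hf_bias_flat, features):
    """NumPy stand-in for the nnx.Linear forward after porting HF weights."""
    F = int(features)
    in_features = int(np.shape(x)[-1])
    kernel = np.asarray(hf_kernel_flat).reshape(F, in_features).T  # (in, out)
    return np.asarray(x) @ kernel + np.asarray(hf_bias_flat)

# Test 1: identity selector plus a zero row; zero bias assumed.
t1 = reference_forward(
    np.array([1.0, 2.0, 3.0]),
    np.array([1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], dtype=float),
    np.zeros(4),
    4.0,
)

# Test 2: all-ones kernel with bias [1, -1].
t2 = reference_forward(np.array([1.0, 2.0, 3.0]), np.ones(6), np.array([1.0, -1.0]), 2.0)

# Test 4 style: a PyTorch-layout 45-degree rotation (y = rot @ x) applied to (1, 0).
c = s = np.sqrt(0.5)
rot = np.array([[c, -s],
                [s, c]])
t4 = reference_forward(np.array([1.0, 0.0]), rot.ravel(), np.zeros(2), 2.0)

print(t1)  # [1. 2. 3. 0.]
print(t2)  # [7. 5.]
print(t4)  # approximately [0.7071 0.7071]
```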
Where this fits in production
A real port from a HuggingFace checkpoint:
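import torch
import jax.numpy as jnp

# Load the PyTorch state dict on CPU (path and key names are illustrative).
sd = torch.load("model.pt", map_location="cpu")
nnx_model = MyNNXModel(...)  # your nnx Module with a matching structure
# Linear weight: (out, in) -> (in, out). .float() because NumPy has no bfloat16.
nnx_model.attn.q_proj.kernel.value = jnp.asarray(sd["attn.q_proj.weight"].float().numpy()).T
# Bias: (out,) in both frameworks, no transpose.
nnx_model.attn.q_proj.bias.value = jnp.asarray(sd["attn.q_proj.bias"].float().numpy())
# ... and so on for every layer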
The recipe is always: jnp.asarray(pt_tensor).T for Linear/Dense
weights; biases unchanged. (Convolutions and attention have other
quirks; cover those when you hit them.)
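Applied across a whole checkpoint, the recipe can be sketched as a loop over the state dict. port_linear_state_dict is a hypothetical helper, NumPy arrays stand in for tensors, and the key names are illustrative. It assumes every 2-D entry ending in ".weight" is a Linear weight; an nn.Embedding weight would also be 2-D but must not be transposed, which is why anything outside the rule raises instead of guessing.

```python
import numpy as np

def port_linear_state_dict(state_dict):
    """Convert PyTorch-layout Linear params to Flax layout.

    Assumes every 2-D ".weight" entry is an nn.Linear weight and every
    ".bias" entry is a 1-D bias. Other layer types are NOT handled.
    """
    ported = {}
    for name, tensor in state_dict.items():
        arr = np.asarray(tensor)
        if name.endswith(".weight") and arr.ndim == 2:
            # (out, in) -> (in, out), renamed to the nnx attribute "kernel"
            ported[name[: -len(".weight")] + ".kernel"] = arr.T
        elif name.endswith(".bias") and arr.ndim == 1:
            ported[name] = arr  # biases are unchanged
        else:
            raise ValueError(f"no porting rule for {name!r} with shape {arr.shape}")
    return ported

# Hypothetical key names, for illustration only.
sd = {
    "attn.q_proj.weight": np.zeros((8, 4)),
    "attn.q_proj.bias": np.zeros(8),
}
ported = port_linear_state_dict(sd)
print(ported["attn.q_proj.kernel"].shape)  # (4, 8)
print(ported["attn.q_proj.bias"].shape)    # (8,)
```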
Common pitfalls
- Forgetting the transpose. If F == in_features the shapes still line up, so x @ kernel won't raise an error, but the output is mathematically wrong.
- Reshape order. reshape(F, in_features) matches PyTorch's row-major storage, and the transpose then gives the Flax layout. Calling reshape(in_features, F) without transposing yields an array of the right shape but generally the wrong contents; for a square layer the two agree only when the weight matrix is symmetric.
- Biases stay 1-D. Bias has shape (out_features,) in both frameworks, so no transpose is needed.
- Replacing the wrapper. Always assign to .value, never model.kernel = arr. The latter swaps out the nnx.Param variable itself for a bare array instead of updating the value it holds.
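The first three pitfalls are easy to reproduce. This NumPy sketch uses small made-up matrices; the fourth pitfall is specific to nnx Variables and is not shown here.

```python
import numpy as np

# Pitfall 1: square layer, no transpose. Shapes line up, so nothing errors.
w_sq = np.array([[1.0, 2.0],
                 [3.0, 4.0]])   # PyTorch layout (out, in), not symmetric
x_sq = np.array([1.0, 0.0])
right_sq = x_sq @ w_sq.T        # correct Flax-style product
wrong_sq = x_sq @ w_sq          # silently wrong

# Pitfall 2: reshaping straight to (in, out) gives the right shape,
# different contents.
F, in_features = 2, 3
flat = np.arange(F * in_features, dtype=float)  # [0, 1, 2, 3, 4, 5]
right_kernel = flat.reshape(F, in_features).T   # [[0, 3], [1, 4], [2, 5]]
wrong_kernel = flat.reshape(in_features, F)     # [[0, 1], [2, 3], [4, 5]]

# Pitfall 3: a 1-D bias has nothing to transpose; .T is a no-op.
bias = np.array([1.0, -1.0])

print(right_sq, wrong_sq)                          # [1. 3.] [1. 2.]
print(right_kernel.shape == wrong_kernel.shape)    # True
print(np.array_equal(right_kernel, wrong_kernel))  # False
print(np.array_equal(bias.T, bias))                # True
```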
Problem
Write load_hf_weights(seed, x, hf_kernel_flat, hf_bias_flat, features):
- Build a fresh nnx.Linear(in_features=int(x.shape[-1]), out_features=int(features), rngs=nnx.Rngs(int(seed))).
- Reshape hf_kernel_flat to (features, in_features) (PyTorch layout), then transpose to (in_features, features) (Flax layout).
- Assign the result to model.kernel.value. Assign hf_bias_flat to model.bias.value directly (no transpose).
- Return model(x).
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- hf_kernel_flat: 1-D JAX array of length features * in_features (PyTorch row-major).
- hf_bias_flat: 1-D JAX array of length features.
- features: int (passed as float).
Output: 1-D array of length features, the model's forward pass.
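The input contract above can be checked before any model is built. check_inputs is a hypothetical helper, written with NumPy so it runs without JAX; it casts the float-encoded integers and verifies each length against the spec.

```python
import numpy as np

def check_inputs(seed, x, hf_kernel_flat, hf_bias_flat, features):
    """Hypothetical pre-flight check that mirrors the input spec."""
    x = np.asarray(x)
    if x.ndim != 1:
        raise ValueError(f"x must be 1-D, got shape {x.shape}")
    seed, F = int(seed), int(features)  # both arrive as floats
    in_features = int(x.shape[-1])
    if np.shape(hf_kernel_flat) != (F * in_features,):
        raise ValueError("hf_kernel_flat must have length features * in_features")
    if np.shape(hf_bias_flat) != (F,):
        raise ValueError("hf_bias_flat must have length features")
    return seed, F, in_features

print(check_inputs(0.0, np.ones(3), np.ones(12), np.zeros(4), 4.0))  # (0, 4, 3)
```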