
NNX Bridge: Load HuggingFace Weights into NNX

Why this matters

Most pretrained checkpoints you'll encounter (Llama, GPT-2, BERT, Mistral, Qwen, you name it) ship as PyTorch state dicts, either pickled (pytorch_model.bin) or as safetensors files. On the HuggingFace Hub they keep PyTorch layout conventions even when you load them through the transformers library's framework-agnostic APIs.

To use those weights in JAX/Flax/nnx, you have to translate the layout. The single most important difference: PyTorch nn.Linear.weight has shape (out_features, in_features), while Flax (Linen and nnx alike) uses (in_features, out_features). One transpose. Get this right and the rest is mechanical.

The convention difference, in one diagram

PyTorch:

y = x @ W.T + b              # W is (out, in), so W.T is (in, out)
weight.shape = (out, in)

Flax:

y = x @ kernel + b            # kernel is (in, out)
kernel.shape = (in, out)

The math is the same; the stored layout differs. To port:

flax_kernel = pt_weight.T     # transpose (out, in) → (in, out)
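
The identity is easy to check numerically with plain jax.numpy: build a PyTorch-style (out, in) weight, apply both formulas, and compare. The weight, bias, and input values here are arbitrary.

```python
import jax.numpy as jnp

# PyTorch-style weight for 3 inputs -> 2 outputs, stored as (out, in).
pt_weight = jnp.array([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])
b = jnp.array([0.5, -0.5])
x = jnp.array([1.0, 0.0, -1.0])

# PyTorch convention: y = x @ W.T + b
y_torch = x @ pt_weight.T + b

# Flax convention: y = x @ kernel + b, with kernel = W.T of shape (in, out)
flax_kernel = pt_weight.T
y_flax = x @ flax_kernel + b

print(flax_kernel.shape)  # (3, 2)
print(y_torch, y_flax)    # both [-1.5, -2.5]
```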

The full port

from flax import nnx

def load_hf_weights(seed, x, hf_kernel_flat, hf_bias_flat, features):
    F = int(features)                # the harness passes ints as floats
    in_features = int(x.shape[-1])

    # Build a fresh nnx Linear (init values will be overwritten).
    rngs  = nnx.Rngs(int(seed))
    model = nnx.Linear(in_features=in_features, out_features=F, rngs=rngs)

    # HF ships the kernel flat in PyTorch layout (out, in).
    hf_kernel  = hf_kernel_flat.reshape(F, in_features)
    nnx_kernel = hf_kernel.T                  # (in, out) for Flax

    # Overwrite the Variables' values in place, keeping the nnx.Param wrappers.
    model.kernel.value = nnx_kernel
    model.bias.value   = hf_bias_flat
    return model(x)

Why the inputs are flat: in production the HF tensor is already 2-D, but the test harness only handles flat arrays. The reshape just undoes that flattening; the transpose is exactly what you'd apply to a real state_dict entry.
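
To see that the reshape only undoes the flattening, here is a small jax.numpy check. It assumes the harness flattens the (F, in_features) tensor in row-major order, as the problem statement says; the shapes are arbitrary.

```python
import jax.numpy as jnp

F, in_features = 2, 3

# Stand-in for a real 2-D state_dict entry in PyTorch layout (out, in).
pt_weight = jnp.arange(F * in_features, dtype=jnp.float32).reshape(F, in_features)

# What the harness passes: the same data flattened row-major.
hf_kernel_flat = pt_weight.reshape(-1)

# reshape(F, in_features) recovers the original 2-D tensor exactly...
restored = hf_kernel_flat.reshape(F, in_features)
assert (restored == pt_weight).all()

# ...so the only real layout conversion is the transpose to (in, out).
nnx_kernel = restored.T
print(nnx_kernel.shape)  # (3, 2)
```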

Reading test cases

Test 1: HF kernel is [[1,0,0],[0,1,0],[0,0,1],[0,0,0]] (F=4, in=3), with zero bias. The first three rows form an identity selector; the fourth is zero. So model([1, 2, 3]) = [1, 2, 3, 0].

Test 2: HF kernel is all ones, F=2, in=3, with bias [1, -1]. Each row sums the input to 6, so model([1, 2, 3]) = [6+1, 6-1] = [7, 5].

Test 3: HF kernel is a permutation matrix cycling (0, 1, 2) → (1, 2, 0), with a constant bias of 0.5 on every output.

Test 4 (hidden): HF kernel encodes a 45° rotation, applied to (1, 0). The output is the unit vector at 45°, (√2/2, √2/2) ≈ (0.7071, 0.7071).

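Each case is interpretable, so you can check the transpose by hand. If you reshape to (F, in_features) and skip the transpose, the non-square cases (Tests 1 and 2) fail with a shape mismatch, and the square ones (Tests 3 and 4) run but return wrong values: the inverse permutation and a rotation by -45°.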
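
Tests 1, 2, and 4 can be reproduced without Flax by emulating the layer's forward pass, x @ kernel + bias, in jax.numpy. linear_forward is a hypothetical helper written for this check. The zero biases for Tests 1 and 4 are inferred from the expected outputs, and the Test 4 kernel is an assumed standard counter-clockwise rotation matrix, since the hidden test's exact values aren't given. Test 3 is skipped because the direction of its cycle isn't specified.

```python
import math
import jax.numpy as jnp

def linear_forward(x, hf_kernel_flat, hf_bias_flat, features):
    """Hypothetical helper: the port's math without building an nnx.Linear."""
    in_features = x.shape[-1]
    # Reshape to PyTorch layout (out, in), then transpose to Flax layout (in, out).
    kernel = hf_kernel_flat.reshape(features, in_features).T
    return x @ kernel + hf_bias_flat  # same computation as nnx.Linear's forward

# Test 1: identity selector plus a zero row, zero bias.
t1 = linear_forward(jnp.array([1.0, 2.0, 3.0]),
                    jnp.array([1.0, 0, 0,  0, 1, 0,  0, 0, 1,  0, 0, 0]),
                    jnp.zeros(4), 4)
print(t1)  # [1. 2. 3. 0.]

# Test 2: all-ones kernel, bias [1, -1].
t2 = linear_forward(jnp.array([1.0, 2.0, 3.0]), jnp.ones(6), jnp.array([1.0, -1.0]), 2)
print(t2)  # [7. 5.]

# Test 4 (assumed kernel): counter-clockwise 45° rotation, zero bias.
c = s = math.sqrt(0.5)
rot_flat = jnp.array([c, -s, s, c])  # rows of [[cos, -sin], [sin, cos]]
t4 = linear_forward(jnp.array([1.0, 0.0]), rot_flat, jnp.zeros(2), 2)
print(t4)  # approximately [0.7071 0.7071]

# Same kernel with the transpose skipped: no error, but it rotates by -45°.
wrong = jnp.array([1.0, 0.0]) @ rot_flat.reshape(2, 2)
print(wrong)  # approximately [0.7071 -0.7071]
```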

Where this fits in production

A real port from a HuggingFace checkpoint:

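import jax.numpy as jnp
import torch

sd = torch.load("model.pt", map_location="cpu")  # PyTorch state_dict: {name: tensor}
nnx_model = MyNNXModel(...)                       # placeholder for your own nnx port

# .float() upcasts bf16 weights, since NumPy has no bfloat16 dtype.
nnx_model.attn.q_proj.kernel.value = jnp.asarray(sd["attn.q_proj.weight"].float().numpy()).T
nnx_model.attn.q_proj.bias.value   = jnp.asarray(sd["attn.q_proj.bias"].float().numpy())
# ... and so on for every layer (key names here are illustrative)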

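For nn.Linear weights the recipe is jnp.asarray(pt_tensor).T; biases are copied unchanged. Check each module's layout rather than assuming, though: GPT-2 on HuggingFace stores its projections in a Conv1D module whose weight is already (in, out), so those need no transpose. (Convolutions and attention have other quirks; cover those when you hit them.)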
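
When a checkpoint has many layers, it can help to convert the whole state_dict in one pass instead of layer by layer. The sketch below uses a hypothetical helper, convert_state_dict, and assumes the tensors have already been turned into NumPy arrays. It takes the set of nn.Linear weight keys explicitly rather than guessing from names or shapes, because a shape-based guess would wrongly transpose modules like GPT-2's Conv1D. The key names are illustrative.

```python
import numpy as np
import jax.numpy as jnp

def convert_state_dict(state_dict, linear_weight_keys):
    """Hypothetical helper: map a PyTorch-layout state_dict to Flax layout.

    state_dict: {name: np.ndarray}, e.g. tensors already converted with .numpy().
    linear_weight_keys: names of nn.Linear weights, which need (out, in) -> (in, out).
    Everything else (biases, modules already stored as (in, out)) is copied as-is.
    """
    converted = {}
    for name, array in state_dict.items():
        arr = jnp.asarray(array)
        converted[name] = arr.T if name in linear_weight_keys else arr
    return converted

# Toy state_dict for one Linear(3 -> 2) layer, in PyTorch layout.
sd = {
    "q_proj.weight": np.arange(6, dtype=np.float32).reshape(2, 3),  # (out, in)
    "q_proj.bias": np.array([0.1, 0.2], dtype=np.float32),          # (out,)
}
params = convert_state_dict(sd, linear_weight_keys={"q_proj.weight"})
print(params["q_proj.weight"].shape)  # (3, 2)
print(params["q_proj.bias"].shape)    # (2,)
```

Each converted array can then be assigned to the matching .value on your nnx model.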

Common pitfalls

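  • Forgetting the transpose. For a square layer (F == in_features), (F, in_features) and (in_features, F) are the same shape, so x @ kernel runs without error, but the output is mathematically wrong: you've computed x @ W instead of x @ W.T.
  • Reshape order. reshape(F, in_features) matches the row-major order of the flattened PyTorch tensor; the transpose then swaps the axes. reshape(in_features, F) without a transpose is not equivalent: it reinterprets the same buffer with the wrong row length and scrambles the weights. (For a square layer it returns W itself, which matches W.T only when W is symmetric.)
  • Biases stay 1-D. A bias has shape (out_features,) in both frameworks, so copy it as-is; no transpose needed.
  • Replacing the wrapper. Always assign through model.kernel.value = arr, never model.kernel = arr: rebinding the attribute can replace the nnx.Param with a bare array that nnx no longer treats as a parameter (for example, when filtering with nnx.state(model, nnx.Param)).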
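
The first two pitfalls are easy to reproduce with jax.numpy alone. This sketch uses an arbitrary 2×3 flat buffer and an arbitrary non-symmetric 2×2 matrix.

```python
import jax.numpy as jnp

# Reshape order: a flat PyTorch-layout buffer for F=2 outputs, 3 inputs.
F, in_features = 2, 3
flat = jnp.arange(1.0, 7.0)                 # values 1..6, row-major (out, in)

right = flat.reshape(F, in_features).T      # [[1, 4], [2, 5], [3, 6]]: true (in, out) kernel
wrong = flat.reshape(in_features, F)        # [[1, 2], [3, 4], [5, 6]]: same shape, scrambled
print(jnp.array_equal(right, wrong))        # False

# Forgetting the transpose: a square, non-symmetric layer raises no error.
W = jnp.array([[0.0, 1.0],
               [0.0, 0.0]])                 # PyTorch layout (out, in)
x = jnp.array([1.0, 2.0])
print(x @ W.T)  # correct:         [2. 0.]
print(x @ W)    # silently wrong:  [0. 1.]
```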

Problem

Write load_hf_weights(seed, x, hf_kernel_flat, hf_bias_flat, features):

  1. Build a fresh nnx.Linear(in_features=int(x.shape[-1]), out_features=int(features), rngs=nnx.Rngs(int(seed))).
  2. Reshape hf_kernel_flat to (features, in_features) (PyTorch layout). Transpose to (in_features, features) (Flax layout).
  3. Assign to model.kernel.value. Assign hf_bias_flat to model.bias.value directly (no transpose).
  4. Return model(x).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • hf_kernel_flat: 1-D JAX array of length features * in_features (PyTorch row-major).
  • hf_bias_flat: 1-D JAX array of length features.
  • features: int (passed as float).

Output: 1-D array of length features, i.e. the result of model(x).
