
HF Weight Load (Kernel Transpose)

Why this matters

Most pre-trained checkpoints in the wild — especially ones in the HuggingFace Hub — were saved from PyTorch. To reuse them in Flax you have to port the weights. The bytes are the same; the layout convention is not.

API              Linear weight shape   Computation
PyTorch / HF     (out, in)             y = x @ W.T + b
Flax nn.Dense    (in, out)             y = x @ kernel + bias

PyTorch stores the matrix that gets transposed at multiply time; Flax stores it pre-transposed. So:

flax_kernel = pt_weight.T

The bias has the same layout in both, (out,), so it needs no transpose.
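The equivalence can be checked numerically. The following is a minimal NumPy sketch; the array names and sizes are illustrative and do not come from any real checkpoint.

```python
import numpy as np

rng = np.random.default_rng(0)
in_dim, out_dim = 3, 2

# PyTorch-style parameters: weight is (out, in), bias is (out,)
pt_weight = rng.normal(size=(out_dim, in_dim))
pt_bias = rng.normal(size=(out_dim,))
x = rng.normal(size=(4, in_dim))

# PyTorch convention: transpose at multiply time
y_pt = x @ pt_weight.T + pt_bias

# Flax convention: store the transposed matrix once, reuse the bias unchanged
flax_kernel = pt_weight.T            # (in, out)
flax_bias = pt_bias                  # (out,)
y_flax = x @ flax_kernel + flax_bias

print(np.allclose(y_pt, y_flax))
```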

Get this wrong and the model will:

  1. Crash with a shape error when in_dim != features, because the matrix product is not defined.
  2. Or silently produce garbage when the shapes happen to be compatible (e.g., square Dense layers, where in_dim == features).

The square-layer case is the silent killer. Always sanity-check against a known input after porting (e.g., compare logits on one sample to the original PyTorch model).
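Both failure modes are easy to reproduce. The sketch below uses plain NumPy matrix products as a stand-in for a Dense layer (an assumption made for brevity; a Flax module may report the mismatch with a different error type).

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 4))

# Non-square layer: a forgotten transpose is caught by a shape error.
w_rect = rng.normal(size=(2, 4))     # PyTorch layout (out=2, in=4)
try:
    x @ w_rect                       # (1, 4) @ (2, 4) is not a valid product
    rect_failed = False
except ValueError:
    rect_failed = True

# Square layer: the same mistake runs without error and returns wrong numbers.
w_sq = rng.normal(size=(4, 4))       # PyTorch layout (out=4, in=4)
y_correct = x @ w_sq.T
y_wrong = x @ w_sq                   # no exception raised

max_abs_diff = np.max(np.abs(y_correct - y_wrong))
print(rect_failed, max_abs_diff)
```

Comparing against a reference output, as in the last two lines, is the only check that catches the square case.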

Beyond Linear: other layout gotchas

  • Conv2D: PyTorch (out, in, kH, kW), Flax (kH, kW, in, out). This permutes all four axes, i.e. transpose(2, 3, 1, 0), not a simple .T.
  • LayerNorm: same (D,) shape for weight/bias, no transpose. Only the name changes: PyTorch's weight is called scale in Flax.
  • Embeddings: same (vocab, D) shape, no transpose.
  • Multi-head attention: depends on the impl — HF often packs Q, K, V into one weight; Flax tends to use three separate Denses. You may need to split the packed weight before transposing.
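The Conv2D and packed-attention cases can be sketched in NumPy as follows. The packed layout here, rows stacked as [Q; K; V], is an assumption for illustration; real checkpoints may pack differently (for example interleaved per head), so check the model code before splitting.

```python
import numpy as np

# Conv2D: PyTorch (out, in, kH, kW) -> Flax (kH, kW, in, out)
pt_conv = np.arange(8 * 3 * 5 * 5, dtype=np.float32).reshape(8, 3, 5, 5)
flax_conv = pt_conv.transpose(2, 3, 1, 0)

# Packed QKV in PyTorch (out, in) layout, ASSUMED to be stacked as [Q; K; V] along axis 0.
d_model = 4
pt_qkv = np.arange(3 * d_model * d_model, dtype=np.float32).reshape(3 * d_model, d_model)

# Split first (while rows still mean "output units"), then transpose each piece.
pt_q, pt_k, pt_v = np.split(pt_qkv, 3, axis=0)      # each (d_model, d_model)
flax_q, flax_k, flax_v = pt_q.T, pt_k.T, pt_v.T     # each (in, out)

print(flax_conv.shape, flax_q.shape)
```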

How real “load HF into Flax” libraries do it

Tools like flax_to_pt or HF’s own convert_*_pt_to_jax.py scripts walk the PyTorch state_dict and apply per-layer rules:

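import torch

# remap() and is_linear() are converter-specific helpers you write yourself;
# they are not part of torch or flax.
pt_state = torch.load("model.pt", map_location="cpu")
flax_params = {}
for pt_key, pt_tensor in pt_state.items():
    flax_key = remap(pt_key)         # "encoder.layer.0.weight" -> "Encoder/Dense_0/kernel"
    if "weight" in pt_key and is_linear(pt_key):
        flax_params[flax_key] = pt_tensor.T.numpy()   # (out, in) -> (in, out)
    else:
        flax_params[flax_key] = pt_tensor.numpy()     # copied without a layout change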

The transpose-or-not rules are the heart of the converter. This problem isolates that single rule.
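To make the per-layer rules concrete, here is a toy, self-contained converter. Everything in it is hypothetical: the key names, the LINEAR_MODULES set, and the rule that any other weight is a LayerNorm scale. NumPy arrays stand in for PyTorch tensors so the sketch runs without torch.

```python
import numpy as np

# Hypothetical PyTorch-style state dict (NumPy arrays stand in for tensors).
pt_state = {
    "dense.weight": np.ones((2, 3), dtype=np.float32),   # Linear: (out, in)
    "dense.bias": np.zeros((2,), dtype=np.float32),
    "norm.weight": np.ones((2,), dtype=np.float32),      # LayerNorm: (D,)
    "norm.bias": np.zeros((2,), dtype=np.float32),
}

# Assumed: the converter knows which module names are linear layers.
LINEAR_MODULES = {"dense"}


def convert(state):
    """Build a nested Flax-style param dict from a flat PyTorch-style one."""
    flax_params = {}
    for pt_key, array in state.items():
        module, leaf = pt_key.rsplit(".", 1)
        target = flax_params.setdefault(module, {})
        if leaf == "weight" and module in LINEAR_MODULES:
            target["kernel"] = array.T       # the transpose rule
        elif leaf == "weight":
            target["scale"] = array          # assumed LayerNorm: rename only
        else:
            target[leaf] = array             # bias: copy as-is
    return flax_params


flax_params = convert(pt_state)
print({k: {n: a.shape for n, a in v.items()} for k, v in flax_params.items()})
```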

Problem

You’re given:

  • seed, x: used to init the Flax nn.Dense(features) (we do NOT use the random init values; we just need the param tree shape).
  • hf_kernel_flat: 1-D, holding an HF-style weight matrix flattened from shape (features, in_dim).
  • hf_bias_flat: 1-D, holding the HF bias of shape (features,).

Steps:

  1. Reshape hf_kernel_flat to (features, in_dim) (the HF convention).
  2. Transpose to get the Flax kernel of shape (in_dim, features).
  3. Build new_params = {"kernel": flax_kernel, "bias": hf_bias}, overriding the random init.
  4. model.apply({"params": new_params}, x).
  5. Return the flattened apply result.
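Steps 1 and 2 cannot be merged into a single reshape. A small NumPy sketch with made-up values shows that reshaping the flat buffer straight to (in_dim, features) yields the right shape but the wrong entries:

```python
import numpy as np

features, in_dim = 2, 3
hf_kernel_flat = np.arange(features * in_dim, dtype=np.float32)   # [0, 1, 2, 3, 4, 5]

# Steps 1-2: restore the HF layout, then transpose.
hf_kernel = hf_kernel_flat.reshape(features, in_dim)    # [[0, 1, 2], [3, 4, 5]]
flax_kernel = hf_kernel.T                               # [[0, 3], [1, 4], [2, 5]]

# Tempting shortcut: same shape as flax_kernel, different contents.
wrong_kernel = hf_kernel_flat.reshape(in_dim, features)  # [[0, 1], [2, 3], [4, 5]]

print(np.array_equal(flax_kernel, wrong_kernel))   # False
```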

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, in_dim).
  • hf_kernel_flat: 1-D (features * in_dim,).
  • hf_bias_flat: 1-D (features,).
  • features: float (cast to int).

Output: 1-D (N * features,) — the flattened apply result.

