Param Sharing — One Module, Two Call Sites

Why this matters

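“Parameter sharing” means two (or more) call sites use the SAME weights. The model has fewer parameters than an unshared version, and both call sites apply one learned transformation: during training, gradients from every call site accumulate into that single set of weights. Concrete uses: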

  • Tied input/output embeddings (covered in pos 41 and the mini-GPT/mini-T5 problems): the (V, D) embedding matrix E is reused, transposed, as the output projection’s weights (logits = h @ E.T).
  • Universal Transformer / ALBERT: the same Transformer block is applied at every depth, one block used N times. Compared with an N-block stack, the block’s weights are stored once instead of N times.
  • Siamese networks: both branches run the same encoder on different inputs, and the two outputs are then compared (e.g., by a distance).
  • CycleGAN’s identity loss: the generator G: X→Y is also applied to real samples from Y (with the target G(y) ≈ y), so the same weights process inputs from both domains.
  • Recursive nets: a single child-combining layer applied at every internal node of a tree.

In Flax, sharing parameters is done by sharing the Module instance. Define the sub-module ONCE in setup(). Reference it multiple times in __call__. Flax sees only one instance, so it creates only one set of params — and both call sites read the same values.
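The training consequence of “one set of params” is that gradients from every call site add up in the shared weights. A minimal pure-JAX sketch, where the hypothetical functions two_sites and shared stand in for a network with two call sites:

```python
import jax
import jax.numpy as jnp


def two_sites(w1, w2, x):
    # Hypothetical two-layer net with a separate weight argument per call site.
    h = jnp.tanh(x @ w1)               # call site 1
    return jnp.sum(jnp.tanh(h @ w2))   # call site 2


def shared(w, x):
    # Sharing = passing the SAME array to both call sites.
    return two_sites(w, w, x)


kw, kx = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(kw, (4, 4))
x = jax.random.normal(kx, (3, 4))

g_shared = jax.grad(shared)(w, x)
g1, g2 = jax.grad(two_sites, argnums=(0, 1))(w, w, x)

# The shared weight's gradient is the sum of both call sites' gradients.
print(jnp.allclose(g_shared, g1 + g2, atol=1e-5))  # True
```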

The pattern: setup() with one named sub-module

import jax
import flax.linen as nn

class SharedDense(nn.Module):
    hidden: int

    def setup(self):
        self.shared = nn.Dense(self.hidden, name="shared")

    def __call__(self, x):
        h1 = jax.nn.relu(self.shared(x))
        h2 = jax.nn.relu(self.shared(h1))
        return h2

Both calls go through self.shared. The params dict is just:

{
  "params": {
    "shared": {"kernel": ..., "bias": ...}
  }
}

ONE entry. Compare to the non-shared version:

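class UnsharedDense(nn.Module):   # hypothetical name for the non-shared variant
    hidden: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden)(x)   # creates Dense_0
        x = nn.Dense(self.hidden)(x)   # creates Dense_1, FRESH PARAMS
        return x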

Each nn.Dense(hidden) expression constructs a new sub-module, so these two lines create TWO separate sub-modules (Dense_0, Dense_1) with TWO sets of params. No sharing.

Why the second call works (shape constraint)

nn.Dense(hidden) projects from the input’s last dim to hidden, and it infers its kernel shape (in_dim, hidden) from the first input it sees. The second call feeds it h1, whose last dim is hidden, so the same kernel only fits if it is square: in_dim must equal hidden. That means x must have last dim equal to hidden. The test cases enforce this.

In production param-sharing scenarios — like Universal Transformer — the shared block is self-similar (input shape = output shape) by design. A whole Transformer block satisfies that naturally.

setup() vs @nn.compact

The two ways to define a Flax Module:

Style          How sub-modules are declared
@nn.compact    Inline in __call__, at the point of use.
setup()        Up front, as attributes; reused freely.

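setup() is the most explicit way to share a sub-module across call sites: the layer lives on self, so any method can call it any number of times. Sharing inside @nn.compact is also possible, but only if you construct the layer once and keep a reference (dense = nn.Dense(self.hidden), then call dense(...) twice). Writing nn.Dense(self.hidden)(x) twice in a compact body constructs two separate layers.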

Common pitfalls

  • Using @nn.compact and “sharing” by writing nn.Dense(hidden)(...) twice: each expression constructs a separate Dense layer with its own params. Use setup(), or construct the layer once and reuse that reference.
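  • Forgetting name=: harmless here. In setup(), the params key comes from the attribute name (self.shared → "shared") unless name= overrides it, so the key is "shared" either way. Explicit names matter more under @nn.compact, where the default keys (Dense_0, Dense_1, …) follow construction order and shift when layers are added or reordered.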
  • Mismatched shapes for the second call: if hidden ≠ in_dim, the second call fails with a shape error. Either design the architecture to be self-similar, or insert a separate projection.
  • Calling setup() yourself: don’t. Flax calls it automatically once per init/apply.

Verifying the sharing

A quick sanity check: count params with jax.tree_util.tree_leaves. A shared Dense gives 2 leaves (kernel + bias). Non-shared gives 4 leaves. The params dict’s keys also tell the story: one shared key vs two (Dense_0, Dense_1).

Problem

Implement SharedDense(nn.Module):

  • hidden: int field.
  • setup(self) declares ONE sub-module: self.shared = nn.Dense(self.hidden, name="shared").
  • __call__(self, x) does TWO calls through self.shared: h1 = jax.nn.relu(self.shared(x)) then h2 = jax.nn.relu(self.shared(h1)). Returns h2.

Initialize with jax.random.PRNGKey(int(seed)) and x, apply the resulting params to x, and return out.reshape(-1).

The test cases choose x with in_dim == hidden so the same Dense can be applied twice in sequence.

Inputs:

  • seed: float (cast to int) — PRNG seed.
  • x: 2-D (N, hidden) (input dim equals hidden).
  • hidden: float (cast to int) — Dense output dim, also equals input dim.

Output: 1-D, length N * hidden.

Hints

flax param-sharing setup
