Param Sharing — One Module, Two Call Sites
Why this matters
“Parameter sharing” means two call sites use the SAME weights. The model has fewer parameters than it would otherwise, and both call sites learn the same transformation: during training, the gradients from every call site are summed into that one set of weights. Concrete uses:
- Tied input/output embeddings (covered in pos 41 and the mini-GPT/mini-T5 problems): the (V, D) embedding matrix is reused as the output projection’s weights.
- Universal Transformer / ALBERT: the same Transformer block is applied at every depth, so one block is used N times. This drastically cuts the parameter count compared with an N-block stack.
- Siamese networks: both branches share the encoder; only the application differs (e.g., one input vs another, then compare).
- Cycle-GAN’s identity mapping: the same generator is applied to two different domains.
- Recursive nets: a single child-combining layer applied at every internal node of a tree.
In Flax, sharing parameters is done by sharing the Module
instance. Define the sub-module ONCE in setup(). Reference it
multiple times in __call__. Flax sees only one instance, so it
creates only one set of params — and both call sites read the
same values.
The pattern: setup() with one named sub-module
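import jax
import flax.linen as nn

class SharedDense(nn.Module):
    hidden: int

    def setup(self):
        # ONE Dense instance; both call sites below reuse it.
        self.shared = nn.Dense(self.hidden, name="shared")

    def __call__(self, x):
        h1 = jax.nn.relu(self.shared(x))   # first call site
        h2 = jax.nn.relu(self.shared(h1))  # second call site: same kernel and bias
        return h2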
Both calls go through self.shared. The params dict is just:
{
    "params": {
        "shared": {"kernel": ..., "bias": ...}
    }
}
ONE entry. Compare to the non-shared version:
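class UnsharedDense(nn.Module):  # hypothetical name for the non-shared variant
    hidden: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden)(x)  # creates Dense_0
        x = nn.Dense(self.hidden)(x)  # creates Dense_1, FRESH PARAMS
        return x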
Constructing nn.Dense(self.hidden) twice creates TWO separate sub-modules
with TWO sets of params: no sharing.
Why the second call works (shape constraint)
nn.Dense(hidden) projects from the input’s last dim to
hidden. For the SAME Dense to be called TWICE in sequence, its
kernel needs to be square: input dim and output dim must match.
That means our x must have last-dim equal to hidden. The test
cases enforce this.
In production param-sharing scenarios — like Universal Transformer — the shared block is self-similar (input shape = output shape) by design. A whole Transformer block satisfies that naturally.
setup() vs @nn.compact
The two ways to define a Flax Module:
| Style | How submodules are defined |
|---|---|
| @nn.compact | Submodules created inline in __call__. |
| setup() | Submodules declared up front; reused freely. |
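setup() is the most explicit way to share a sub-module across
multiple call sites within a single __call__. Sharing inside
@nn.compact is possible but easy to get wrong: every nn.Dense(hidden)
constructor call inside a compact body creates a new layer, so the
layer must be constructed once, bound to a local variable, and that
same instance called at each site.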
Common pitfalls
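- Using @nn.compact and trying to “share” by calling nn.Dense(hidden) twice: this creates two separate Dense layers with different params. Use setup() instead, or construct the layer once inside the compact body and reuse that instance.
- Omitting name=: this still works. In setup(), a sub-module is named after the attribute it is assigned to, so self.shared = nn.Dense(self.hidden) already yields the params key "shared"; an explicit name="shared" just makes the key obvious at the definition site. Auto-generated names such as Dense_0 only appear for sub-modules created inline in @nn.compact.
- Mismatched shapes for the second call: if hidden ≠ in_dim, the second call fails with a shape error. Either design the architecture to be self-similar, or insert a separate projection.
- Calling setup() yourself: don’t. Flax calls it automatically once per init/apply.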
Verifying the sharing
A quick sanity check: count the parameter arrays (leaves) with
jax.tree_util.tree_leaves. A shared Dense gives 2 leaves (kernel + bias);
the non-shared version gives 4. The params dict’s keys tell the same
story: one shared key vs two (Dense_0, Dense_1).
Problem
Implement SharedDense(nn.Module):
- hidden: int field.
- setup(self) declares ONE sub-module: self.shared = nn.Dense(self.hidden, name="shared").
- __call__(self, x) does TWO calls through self.shared: h1 = jax.nn.relu(self.shared(x)), then h2 = jax.nn.relu(self.shared(h1)). Returns h2.
Init on (seed, x), apply, return out.reshape(-1).
The test cases choose x with in_dim == hidden so the same
Dense can be applied twice in sequence.
Inputs:
- seed: float (cast to int), the PRNG seed.
- x: 2-D, shape (N, hidden) (the input dim equals hidden).
- hidden: float (cast to int), the Dense output dim, which also equals the input dim.
Output: 1-D, length N * hidden.