Multi-Task Two-Head Loss
Why this matters
Many production ML systems train on more than one task. Multi-task learning shares a representation (the “trunk”) across several prediction targets, each with its own “head”:
- CLIP: image trunk → text-similarity head + classification head.
- Detectors (Faster R-CNN, YOLO): backbone → bbox-regression head + classification head + objectness head.
- MTL recommendation: user embedding trunk → click head + dwell-time head + revenue head.
The wins:
- Data efficiency: the trunk gets gradient signal from every task, not just one, so each shared parameter learns from more supervision.
- Regularization: each head pulls the trunk in a different direction, which discourages it from overfitting to any single task.
- Single inference: one forward pass yields all outputs.
The architecture
A 2-head model on top of a shared trunk:
```
x ──► Dense(H) ──► ReLU ──► h
                            │
                            ├──► Dense(1) ──► out_a
                            └──► Dense(1) ──► out_b
```
The two heads are initialized independently (each Dense submodule gets its own RNG split) and share nothing except their input h.
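```python
import flax.linen as nn


class TwoHead(nn.Module):
    hidden: int = 8  # width of the shared trunk

    @nn.compact
    def __call__(self, x):
        # Shared trunk: Dense + ReLU produces the representation h.
        h = nn.relu(nn.Dense(self.hidden)(x))
        # Two separate Dense(1) heads, each with its own parameters.
        out_a = nn.Dense(1)(h)
        out_b = nn.Dense(1)(h)
        return out_a, out_b
```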
The combined loss
The trickiest part of multi-task learning is the loss combination. The simplest is fixed weighting:
L = α * L_a + (1 - α) * L_b
where α ∈ [0, 1] sets how much task A matters relative to task B.
α = 0.5 weights both tasks equally; α = 1.0 ignores task B, and α = 0.0 ignores task A.
More sophisticated approaches (uncertainty weighting, GradNorm) learn or adapt the task weights during training instead of fixing them, but the principle is the same: the gradient of the combined loss is a weighted sum of the per-task gradients.
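Writing θ_trunk, θ_a, θ_b for the trunk and head parameters (notation introduced here for the derivation), the chain rule makes the weighted-sum statement concrete. Since L_b does not depend on θ_a and L_a does not depend on θ_b, only the trunk mixes both tasks:

```latex
\frac{\partial L}{\partial \theta_{\text{trunk}}}
  = \alpha \,\frac{\partial L_a}{\partial \theta_{\text{trunk}}}
  + (1-\alpha)\,\frac{\partial L_b}{\partial \theta_{\text{trunk}}},
\qquad
\frac{\partial L}{\partial \theta_a} = \alpha \,\frac{\partial L_a}{\partial \theta_a},
\qquad
\frac{\partial L}{\partial \theta_b} = (1-\alpha)\,\frac{\partial L_b}{\partial \theta_b}.
```

So with α = 1.0, head B receives zero gradient and the trunk is shaped by task A alone.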
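```python
import jax.numpy as jnp
pa, pb = out_a.reshape(-1), out_b.reshape(-1)  # (N, 1) -> (N,)
L_a = jnp.mean((pa - y_a) ** 2)  # MSE, task A
L_b = jnp.mean((pb - y_b) ** 2)  # MSE, task B
L = alpha * L_a + (1 - alpha) * L_b
```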
Common pitfalls
- Both heads sharing a parameter accidentally: in JAX/Flax this requires explicit weight tying. By default, two separate Dense(1) submodules get fully independent params, so this is hard to do by mistake (good news).
- Forgetting to flatten heads: Dense(1) returns (N, 1). If y is (N,), the subtraction silently broadcasts to (N, N). Flatten the head output with .reshape(-1) first.
- Imbalanced loss scales: if task A is regression in [0, 100] and task B is in [0, 1], a plain α-mix is dominated by task A, because squared errors grow with the square of the target range. In practice, normalize per-task losses to comparable magnitudes first.
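The last two pitfalls can be reproduced with a few lines of jax.numpy. The toy arrays below are made up for illustration, and normalized_mse is a hypothetical helper: dividing each task's MSE by its target variance is one simple normalization assumed here, not a scheme the problem prescribes.

```python
import jax.numpy as jnp

# Pitfall 1: broadcasting. Predictions are perfect, yet the "MSE" is nonzero.
N = 4
out_a = jnp.arange(N, dtype=jnp.float32).reshape(N, 1)  # head output: (N, 1)
y_a = jnp.arange(N, dtype=jnp.float32)                  # targets: (N,)
wrong = jnp.mean((out_a - y_a) ** 2)                    # (N, 1) - (N,) -> (N, N)
right = jnp.mean((out_a.reshape(-1) - y_a) ** 2)        # flatten first
print((out_a - y_a).shape, float(wrong), float(right))  # (4, 4) 2.5 0.0


# Pitfall 2: loss scales. Same relative error, very different raw MSE.
def normalized_mse(pred, y):
    # Hypothetical helper: MSE divided by target variance, i.e. relative to
    # always predicting the mean. The epsilon guards constant targets.
    return jnp.mean((pred - y) ** 2) / (jnp.var(y) + 1e-8)


y_big = jnp.array([0.0, 50.0, 100.0])  # task A range: [0, 100]
y_small = y_big / 100.0                # task B range: [0, 1]
pred_big, pred_small = y_big + 10.0, y_small + 0.1  # both off by 10% of range

raw = (jnp.mean((pred_big - y_big) ** 2), jnp.mean((pred_small - y_small) ** 2))
norm = (normalized_mse(pred_big, y_big), normalized_mse(pred_small, y_small))
print([float(v) for v in raw])   # about 100 vs 0.01
print([float(v) for v in norm])  # both about 0.06
```

After normalization the two tasks contribute on the same scale, so α again expresses a preference between tasks rather than an accident of units.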
Problem
- Build TwoHead(hidden=8) and initialize it from seed.
- Forward x to get (out_a, out_b).
- Compute L_a = MSE(out_a, y_a) and L_b = MSE(out_b, y_b).
- Compute the combined loss L = α * L_a + (1 - α) * L_b.
- Return [L, L_a, L_b] as a 1-D (3,) array.
There is no training step here, just a single forward pass plus the loss computation, so we can isolate the multi-task structure.
Inputs:
- seed: float (cast to int).
- x: 2-D (N, D).
- y_a: 1-D (N,).
- y_b: 1-D (N,).
- alpha: float in [0, 1].
Output: 1-D (3,) — [combined_loss, loss_a, loss_b].