Multi-Task Two-Head Loss

Why this matters

Real ML systems rarely train on a single task. Multi-task learning shares a representation (the “trunk”) across several prediction targets, each with its own “head”:

  • CLIP: image trunk → text-similarity head + classification head.
  • Detectors (Faster-RCNN, YOLO): backbone → bbox-regression head + classification head + objectness head.
  • MTL recommendation: user embedding trunk → click head + dwell-time head + revenue head.

The wins:

  1. Data efficiency: one trunk gets gradient signal from ALL tasks, not just one, so every shared parameter is trained on more supervision.
  2. Regularization: each head pulls the trunk in a different direction, which discourages the trunk from overfitting to any single task.
  3. Single inference: one forward pass yields all outputs.

The architecture

A 2-head model on top of a shared trunk:

x ──► Dense(H) ──► ReLU ──► h
                            │
                            ├──► Dense(1) ──► out_a
                            └──► Dense(1) ──► out_b

The two heads are initialized independently (different RNG splits). They share NOTHING except the input h.

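import flax.linen as nn

class TwoHead(nn.Module):
    hidden: int = 8  # width of the shared trunk

    @nn.compact
    def __call__(self, x):
        h = nn.relu(nn.Dense(self.hidden)(x))  # shared trunk: (N, D) -> (N, hidden)
        out_a = nn.Dense(1)(h)  # head A: (N, 1)
        out_b = nn.Dense(1)(h)  # head B: separate params, also (N, 1)
        return out_a, out_b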

The combined loss

The trickiest part of multi-task learning is the loss combination. The simplest is fixed weighting:

L = α * L_a + (1 - α) * L_b

Where α ∈ [0, 1] decides “how much do we care about task A vs B.” α = 0.5 is symmetric. α = 1.0 ignores task B.

More sophisticated approaches (uncertainty weighting, GradNorm) learn α per-task or per-step, but the principle is the same: the gradient is a weighted sum of the per-task gradients.

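pa, pb = out_a.reshape(-1), out_b.reshape(-1)  # (N, 1) -> (N,) to match y_a, y_b
L_a = jnp.mean((pa - y_a) ** 2)  # MSE for task A
L_b = jnp.mean((pb - y_b) ** 2)  # MSE for task B
L = alpha * L_a + (1 - alpha) * L_b  # fixed-weight combination, alpha in [0, 1]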
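
The claim that the gradient is a weighted sum of the per-task gradients can be checked numerically. The sketch below is an illustration under stated assumptions: it swaps the Flax module for a plain dict of arrays (no biases) to keep the example short, and the data, shapes, and alpha value are made up.

```python
import jax
import jax.numpy as jnp


def forward(params, x):
    # Shared trunk followed by two independent linear heads, flattened to (N,).
    h = jax.nn.relu(x @ params["trunk"])
    return (h @ params["head_a"]).reshape(-1), (h @ params["head_b"]).reshape(-1)


def loss_a(params, x, y_a):
    pa, _ = forward(params, x)
    return jnp.mean((pa - y_a) ** 2)


def loss_b(params, x, y_b):
    _, pb = forward(params, x)
    return jnp.mean((pb - y_b) ** 2)


def combined(params, x, y_a, y_b, alpha):
    return alpha * loss_a(params, x, y_a) + (1 - alpha) * loss_b(params, x, y_b)


# Hypothetical toy setup: N=5, D=3, hidden=8.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "trunk": jax.random.normal(k1, (3, 8)),
    "head_a": jax.random.normal(k2, (8, 1)),
    "head_b": jax.random.normal(k3, (8, 1)),
}
x = jax.random.normal(k4, (5, 3))
y_a = jnp.arange(5.0)
y_b = jnp.ones(5)
alpha = 0.3

g = jax.grad(combined)(params, x, y_a, y_b, alpha)
g_a = jax.grad(loss_a)(params, x, y_a)
g_b = jax.grad(loss_b)(params, x, y_b)

# Weighted sum of the per-task gradients, leaf by leaf.
mix = jax.tree_util.tree_map(lambda a, b: alpha * a + (1 - alpha) * b, g_a, g_b)
```

Comparing g with mix leaf by leaf shows they agree up to floating-point error. Note also that g_a["head_b"] is all zeros: each head receives gradient only from its own loss, while the trunk receives the α-weighted mix of both.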

Common pitfalls

  • Both heads sharing a parameter accidentally: in JAX/Flax this requires explicit weight tying — by default, two separate Dense(1) submodules get fully independent params. So this is hard to do by mistake (good news).
  • Forgetting to flatten heads: Dense(1) returns (N, 1). If y is (N,), the subtraction silently broadcasts to (N, N). Call .reshape(-1) on each head output before computing the loss.
  • Imbalanced loss scales: if task A regresses targets in [0, 100] and task B targets in [0, 1], task A dominates a pure α-mix, because MSE grows with the square of the target scale. In practice, normalize per-task losses to comparable magnitudes first.
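
The last two pitfalls can be reproduced in a few lines. This is a sketch with made-up numbers; in particular, dividing each MSE by the variance of its targets (the hypothetical helper scaled_mse) is just one possible normalization, assumed here for illustration.

```python
import jax.numpy as jnp

# Pitfall: (N, 1) predictions against (N,) targets broadcast to (N, N).
pred = jnp.array([[1.0], [2.0], [3.0]])  # shape (3, 1), as Dense(1) returns
y = jnp.array([1.0, 2.0, 3.0])           # shape (3,)

bad = pred - y               # shape (3, 3): every prediction vs every target
good = pred.reshape(-1) - y  # shape (3,): elementwise, as intended

bad_mse = jnp.mean(bad ** 2)    # nonzero even though the predictions are perfect
good_mse = jnp.mean(good ** 2)  # 0.0


# Pitfall: loss scales differ by orders of magnitude across tasks.
def scaled_mse(pred, y):
    # Assumption: normalize by the target variance to get a unit-free loss.
    return jnp.mean((pred - y) ** 2) / (jnp.var(y) + 1e-8)


y_a = jnp.array([0.0, 50.0, 100.0])  # task A targets in [0, 100]
y_b = jnp.array([0.0, 0.5, 1.0])     # task B targets in [0, 1]
pred_a, pred_b = 0.9 * y_a, 0.9 * y_b  # same 10% relative error on both tasks

raw_a = jnp.mean((pred_a - y_a) ** 2)
raw_b = jnp.mean((pred_b - y_b) ** 2)
norm_a = scaled_mse(pred_a, y_a)
norm_b = scaled_mse(pred_b, y_b)
```

Here raw_a is roughly 10,000 times raw_b even though both heads are equally wrong in relative terms, while norm_a and norm_b come out equal, so an α-mix of the normalized losses reflects α rather than the target units.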

Problem

  1. Build TwoHead(hidden=8) and init from seed.
  2. Forward x; get (out_a, out_b).
  3. Compute L_a = MSE(out_a, y_a), L_b = MSE(out_b, y_b).
  4. Compute combined L = α * L_a + (1 - α) * L_b.
  5. Return [L, L_a, L_b] as a 1-D (3,) array.

No training step here — just a single forward + loss computation so we can isolate the multi-task structure.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D).
  • y_a: 1-D (N,).
  • y_b: 1-D (N,).
  • alpha: float in [0, 1].

Output: 1-D (3,) array [combined_loss, loss_a, loss_b].
