
NNX Multi-Task Loss

Why this matters

Many real architectures predict more than one thing at once: a detector that outputs both class and bounding box, a language model that predicts both next token and entity tag, an autoencoder that reconstructs both pixels and depth. The standard pattern is a shared trunk + several task-specific heads, trained with a weighted sum of the per-task losses.

           ┌──── head_a ──── out_a ──── loss_a (vs y_a)
x ── trunk ┤
           └──── head_b ──── out_b ──── loss_b (vs y_b)

combined_loss = α * loss_a + (1 - α) * loss_b

The trunk receives gradient contributions from BOTH heads, so it learns features useful for both tasks at once. Each head, by contrast, is reached only by its own task's loss, so it specializes to that task.
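To see this gradient routing concretely, here is a minimal sketch in plain JAX rather than nnx (an assumption, chosen so each parameter group can be inspected by name; init_params, forward, and task_losses are hypothetical helpers, and the data is random). It differentiates each task loss separately and checks which parameters each one reaches.

```python
import jax
import jax.numpy as jnp

# Hypothetical plain-JAX stand-in for the trunk + two heads (not nnx).
def init_params(key, d_in, hidden):
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "trunk": {"w": jax.random.normal(k1, (d_in, hidden)) / jnp.sqrt(d_in), "b": jnp.zeros(hidden)},
        "head_a": {"w": jax.random.normal(k2, (hidden, 1)) / jnp.sqrt(hidden), "b": jnp.zeros(1)},
        "head_b": {"w": jax.random.normal(k3, (hidden, 1)) / jnp.sqrt(hidden), "b": jnp.zeros(1)},
    }

def forward(params, x):
    h = jax.nn.relu(x @ params["trunk"]["w"] + params["trunk"]["b"])  # shared features
    out_a = h @ params["head_a"]["w"] + params["head_a"]["b"]
    out_b = h @ params["head_b"]["w"] + params["head_b"]["b"]
    return out_a, out_b

def task_losses(params, x, y_a, y_b):
    out_a, out_b = forward(params, x)
    return jnp.mean((out_a - y_a) ** 2), jnp.mean((out_b - y_b) ** 2)

kp, kx, ka, kb = jax.random.split(jax.random.PRNGKey(0), 4)
params = init_params(kp, 3, 8)
x = jax.random.normal(kx, (16, 3))
y_a = jax.random.normal(ka, (16, 1))
y_b = jax.random.normal(kb, (16, 1))

# Differentiate each task loss on its own.
grad_a = jax.grad(lambda p: task_losses(p, x, y_a, y_b)[0])(params)
grad_b = jax.grad(lambda p: task_losses(p, x, y_a, y_b)[1])(params)

# Largest gradient magnitude each loss sends to the trunk and to the "other" head.
print("trunk:", float(jnp.abs(grad_a["trunk"]["w"]).max()), float(jnp.abs(grad_b["trunk"]["w"]).max()))
print("cross-head:", float(jnp.abs(grad_b["head_a"]["w"]).max()), float(jnp.abs(grad_a["head_b"]["w"]).max()))
```

Because loss_b never reads head_a's parameters (and loss_a never reads head_b's), JAX returns exact zeros for those entries; only the trunk is shared.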

The model

import jax
from flax import nnx

class TwoHead(nnx.Module):
    def __init__(self, d_in, hidden, rngs):
        self.trunk = nnx.Linear(d_in, hidden, rngs=rngs)   # shared trunk
        self.head_a = nnx.Linear(hidden, 1, rngs=rngs)     # head for task a
        self.head_b = nnx.Linear(hidden, 1, rngs=rngs)     # head for task b

    def __call__(self, x):
        h = jax.nn.relu(self.trunk(x))          # features shared by both tasks
        return self.head_a(h), self.head_b(h)   # one output per task

A model that returns a tuple is fine in nnx. nnx.value_and_grad only requires the LOSS to be a scalar; the model’s output is free-form.
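For logging you usually still want the individual task losses. One way to get them from a single forward pass, sketched here in plain JAX as a stand-in for nnx (forward and loss_with_aux are hypothetical names; the parameters are random), is has_aux=True: the function returns (scalar, aux), only the scalar is differentiated, and aux is passed through untouched. nnx.value_and_grad accepts the same has_aux flag.

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    # Hypothetical two-head model: shared ReLU trunk, two linear heads.
    h = jax.nn.relu(x @ params["trunk"])
    return h @ params["head_a"], h @ params["head_b"]   # tuple output is fine

def loss_with_aux(params, x, y_a, y_b, alpha):
    out_a, out_b = forward(params, x)
    loss_a = jnp.mean((out_a - y_a) ** 2)
    loss_b = jnp.mean((out_b - y_b) ** 2)
    combined = alpha * loss_a + (1.0 - alpha) * loss_b
    # First element: the scalar that is differentiated. Second: auxiliary data.
    return combined, (loss_a, loss_b)

keys = jax.random.split(jax.random.PRNGKey(1), 5)
params = {
    "trunk": jax.random.normal(keys[0], (4, 8)) * 0.5,
    "head_a": jax.random.normal(keys[1], (8, 1)) * 0.5,
    "head_b": jax.random.normal(keys[2], (8, 1)) * 0.5,
}
x = jax.random.normal(keys[3], (32, 4))
y_a = jnp.sum(x, axis=1, keepdims=True)
y_b = jax.random.normal(keys[4], (32, 1))

# Gradients are taken w.r.t. params (argnums=0 by default); alpha is just data.
(combined, (loss_a, loss_b)), grads = jax.value_and_grad(loss_with_aux, has_aux=True)(
    params, x, y_a, y_b, 0.3
)
print(combined, loss_a, loss_b)
```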

The loss

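import jax.numpy as jnp

def loss_fn(model, x, y_a, y_b):
    # alpha is captured from the enclosing scope (the alpha passed to multi_task_loss)
    out_a, out_b = model(x)
    loss_a = jnp.mean((out_a - y_a) ** 2)   # MSE for task a
    loss_b = jnp.mean((out_b - y_b) ** 2)   # MSE for task b
    return alpha * loss_a + (1.0 - alpha) * loss_b   # single scalar to differentiate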

Notice that loss_a and loss_b are both computed inside loss_fn, so they are part of the graph being differentiated. Because the gradient is linear, the grads that nnx.value_and_grad returns for the trunk are α * grads_from_loss_a + (1 - α) * grads_from_loss_b: the chain rule applies the weighting automatically.
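A quick numerical check of that claim, again as a plain-JAX sketch rather than nnx (task_losses and combined_loss are hypothetical helpers on random data): the gradient of the combined loss should equal the hand-weighted sum of the per-task gradients, leaf by leaf, up to float32 rounding.

```python
import jax
import jax.numpy as jnp

def task_losses(params, x, y_a, y_b):
    # Hypothetical two-head model: shared ReLU trunk, two linear heads.
    h = jax.nn.relu(x @ params["trunk"])
    loss_a = jnp.mean((h @ params["head_a"] - y_a) ** 2)
    loss_b = jnp.mean((h @ params["head_b"] - y_b) ** 2)
    return loss_a, loss_b

def combined_loss(params, x, y_a, y_b, alpha):
    loss_a, loss_b = task_losses(params, x, y_a, y_b)
    return alpha * loss_a + (1.0 - alpha) * loss_b

keys = jax.random.split(jax.random.PRNGKey(2), 6)
params = {
    "trunk": jax.random.normal(keys[0], (4, 8)) * 0.5,
    "head_a": jax.random.normal(keys[1], (8, 1)) * 0.5,
    "head_b": jax.random.normal(keys[2], (8, 1)) * 0.5,
}
x = jax.random.normal(keys[3], (32, 4))
y_a = jax.random.normal(keys[4], (32, 1))
y_b = jax.random.normal(keys[5], (32, 1))
alpha = 0.3

g_combined = jax.grad(combined_loss)(params, x, y_a, y_b, alpha)
g_a = jax.grad(lambda p: task_losses(p, x, y_a, y_b)[0])(params)
g_b = jax.grad(lambda p: task_losses(p, x, y_a, y_b)[1])(params)

# Weight the per-task gradients by hand, then compare with the combined gradient.
g_manual = jax.tree_util.tree_map(lambda ga, gb: alpha * ga + (1.0 - alpha) * gb, g_a, g_b)
matches = jax.tree_util.tree_map(
    lambda u, v: bool(jnp.allclose(u, v, atol=1e-6)), g_combined, g_manual
)
print(matches)
```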

Choosing α

The simplest choice is a fixed scalar, like α = 0.5 for equal weighting. In practice people use:

  • Heuristics: weight by how “important” each task is.
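  • Loss balancing: divide each loss by its current magnitude (treated as a constant) so that no task dominates the update merely because its loss lives on a larger scale.
  • Learned weights: learn the per-task weights jointly with the model, for example the “uncertainty weighting” of Kendall et al., which learns a noise scale σ_i for each task and weights its loss by 1/(2σ_i²).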
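The last two options can be sketched as small loss combinators. Both helpers are hypothetical: balanced_loss shows one simple balancing rule (divide each loss by its own current value, frozen with jax.lax.stop_gradient; other schemes normalize gradient norms or use running averages), and uncertainty_weighted_loss uses the log-variance form of Kendall et al.'s regression objective, 1/(2σ²)·L + log σ, with s = log σ² learned per task.

```python
import jax
import jax.numpy as jnp

def balanced_loss(loss_a, loss_b, eps=1e-8):
    # Each loss is divided by its own (frozen) current value, so the
    # effective weight on task i becomes 1 / L_i.
    return (loss_a / (jax.lax.stop_gradient(loss_a) + eps)
            + loss_b / (jax.lax.stop_gradient(loss_b) + eps))

def uncertainty_weighted_loss(loss_a, loss_b, log_vars):
    # log_vars[i] = s_i = log(sigma_i^2), a learnable parameter per task.
    # 1/(2 sigma^2) * L + log(sigma) == 0.5 * exp(-s) * L + 0.5 * s
    losses = jnp.stack([loss_a, loss_b])
    return jnp.sum(0.5 * jnp.exp(-log_vars) * losses + 0.5 * log_vars)

# Balancing: a loss 100x larger gets 100x less weight per unit of loss.
d_a, d_b = jax.grad(balanced_loss, argnums=(0, 1))(100.0, 1.0)
print(d_a, d_b)

# Uncertainty weighting: the gradient w.r.t. s_i vanishes where s_i = log(L_i).
g_s = jax.grad(uncertainty_weighted_loss, argnums=2)(4.0, 0.25, jnp.log(jnp.array([4.0, 0.25])))
print(g_s)  # approximately [0, 0]
```

In training, log_vars would be an extra trainable parameter (for example an nnx.Param on the model) updated alongside the weights; the 0.5 * log_vars term stops σ from growing without bound to shrink every loss to zero.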

For this problem α is given as a fixed input.

Common pitfalls

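  • Returning loss_a and loss_b separately and summing them OUTSIDE loss_fn. value_and_grad then either raises an error (the differentiated output must be a scalar) or, with has_aux=True, differentiates only the first loss and treats the second as auxiliary data, so the gradients ignore task b. Combine the losses inside the function being differentiated.
  • Not casting α to a Python float. The 1.0 - α arithmetic works whether α arrives as a Python float or as a 0-d jnp.float32, but calling float(alpha) at the function boundary keeps the type uniform and fails fast on a malformed input (such as an array with more than one element), just as seed is cast with int.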
  • Using one head for both tasks. That defeats the entire point; the trunk would learn a single representation tied to a single output, with no task specialization.
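The first pitfall is easy to reproduce. A sketch in plain JAX rather than nnx (wrong_loss_fn and right_loss_fn are hypothetical; random data): the tuple-returning version is either rejected outright or, with has_aux=True, trains only on loss_a, leaving head_b with an all-zero gradient even though the combined number computed outside looks correct.

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    # Hypothetical two-head model: shared ReLU trunk, two linear heads.
    h = jax.nn.relu(x @ params["trunk"])
    return h @ params["head_a"], h @ params["head_b"]

def wrong_loss_fn(params, x, y_a, y_b):
    # Pitfall: returns both losses and leaves the weighting for later.
    out_a, out_b = forward(params, x)
    return jnp.mean((out_a - y_a) ** 2), jnp.mean((out_b - y_b) ** 2)

def right_loss_fn(params, x, y_a, y_b, alpha):
    # Fix: combine inside the function that is differentiated.
    loss_a, loss_b = wrong_loss_fn(params, x, y_a, y_b)
    return alpha * loss_a + (1.0 - alpha) * loss_b

keys = jax.random.split(jax.random.PRNGKey(3), 6)
params = {
    "trunk": jax.random.normal(keys[0], (4, 8)) * 0.5,
    "head_a": jax.random.normal(keys[1], (8, 1)) * 0.5,
    "head_b": jax.random.normal(keys[2], (8, 1)) * 0.5,
}
x = jax.random.normal(keys[3], (32, 4))
y_a = jax.random.normal(keys[4], (32, 1))
y_b = jax.random.normal(keys[5], (32, 1))

# Without has_aux, JAX rejects the tuple: gradients need a scalar output.
try:
    jax.value_and_grad(wrong_loss_fn)(params, x, y_a, y_b)
    rejected = False
except TypeError:
    rejected = True

# With has_aux=True it runs, but only loss_a is differentiated; loss_b is aux.
(loss_a, loss_b), g_wrong = jax.value_and_grad(wrong_loss_fn, has_aux=True)(params, x, y_a, y_b)
combined_outside = 0.5 * loss_a + 0.5 * loss_b   # right number, wrong gradients
g_right = jax.grad(right_loss_fn)(params, x, y_a, y_b, 0.5)

print("rejected without has_aux:", rejected)
print("head_b grad, wrong vs right:", float(jnp.abs(g_wrong["head_b"]).max()),
      float(jnp.abs(g_right["head_b"]).max()))
```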

Problem

Implement multi_task_loss(seed, x, y_a, y_b, alpha):

  1. Define TwoHead(nnx.Module) with trunk: nnx.Linear(D, 8), head_a: nnx.Linear(8, 1), head_b: nnx.Linear(8, 1). __call__(self, x) returns (out_a, out_b) after a ReLU on the trunk.
  2. Build with nnx.Rngs(int(seed)) and wrap in nnx.Optimizer(model, optax.sgd(0.05), wrt=nnx.Param).
  3. Define loss_fn(model, x, y_a, y_b) that returns α * MSE(out_a, y_a) + (1 - α) * MSE(out_b, y_b).
  4. ONE optimizer step.
  5. Recompute loss_a_final, loss_b_final, and the combined value at the new params.
  6. Return jnp.array([combined, loss_a_final, loss_b_final]) as 1-D (3,).

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D).
  • y_a, y_b: each 2-D (N, 1).
  • alpha: float in [0, 1].

Output: 1-D array of shape (3,) holding [combined_loss, loss_a, loss_b], all evaluated after the single step.
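For reference, here is the same sequence of steps as a sketch in plain JAX instead of nnx (an assumption: an explicit parameter dict plus a hand-written SGD update, which is what optax.sgd(0.05) computes when momentum is off). multi_task_loss_sketch, init_params, and task_losses are hypothetical names, and the initializer is a simple scaled normal rather than nnx.Linear's default, so the numbers will not match an nnx solution.

```python
import jax
import jax.numpy as jnp

def init_params(key, d_in, hidden):
    # Simple scaled-normal init; NOT nnx.Linear's default initializer.
    k1, k2, k3 = jax.random.split(key, 3)
    def dense(k, n_in, n_out):
        return {"w": jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in), "b": jnp.zeros(n_out)}
    return {"trunk": dense(k1, d_in, hidden), "head_a": dense(k2, hidden, 1), "head_b": dense(k3, hidden, 1)}

def task_losses(params, x, y_a, y_b):
    h = jax.nn.relu(x @ params["trunk"]["w"] + params["trunk"]["b"])
    out_a = h @ params["head_a"]["w"] + params["head_a"]["b"]
    out_b = h @ params["head_b"]["w"] + params["head_b"]["b"]
    return jnp.mean((out_a - y_a) ** 2), jnp.mean((out_b - y_b) ** 2)

def multi_task_loss_sketch(seed, x, y_a, y_b, alpha, lr=0.05, hidden=8):
    alpha = float(alpha)                                   # cast at the boundary
    params = init_params(jax.random.PRNGKey(int(seed)), x.shape[1], hidden)

    def loss_fn(params):
        loss_a, loss_b = task_losses(params, x, y_a, y_b)
        return alpha * loss_a + (1.0 - alpha) * loss_b     # combined INSIDE

    grads = jax.grad(loss_fn)(params)
    # One plain SGD step: p <- p - lr * g for every parameter leaf.
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

    # Re-evaluate both task losses and the combination at the new parameters.
    loss_a, loss_b = task_losses(params, x, y_a, y_b)
    combined = alpha * loss_a + (1.0 - alpha) * loss_b
    return jnp.array([combined, loss_a, loss_b])

kx, ka, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (20, 4))
y_a = jax.random.normal(ka, (20, 1))
y_b = jax.random.normal(kb, (20, 1))
result = multi_task_loss_sketch(0.0, x, y_a, y_b, 0.7)
print(result.shape, result)
```

The lr argument exists only so the sketch can be checked: with lr=0.0 it returns the losses at initialization, and a small positive lr should lower the combined value.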

Hints

flax nnx multi-task training
