NNX Multi-Task Loss
Why this matters
Many real architectures predict more than one thing at once: a detector that outputs both class and bounding box, a language model that predicts both next token and entity tag, an autoencoder that reconstructs both pixels and depth. The standard pattern is a shared trunk + several task-specific heads, trained with a weighted sum of the per-task losses.
```
               ┌── head_a ── out_a ── loss_a (vs y_a)
x ── trunk ────┤
               └── head_b ── out_b ── loss_b (vs y_b)

combined_loss = α * loss_a + (1 - α) * loss_b
```
The trunk's gradient receives contributions from BOTH heads, so the trunk learns features useful for both tasks at once. Each head receives gradient only from its own loss, so the heads specialize.
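To make the shared-trunk claim concrete, here is a minimal pure-JAX sketch. It assumes a hypothetical parameter dict with keys `trunk`, `head_a`, and `head_b` (bias-free matrices) rather than NNX modules. Differentiating `loss_a` alone gives `head_b` an exactly-zero gradient, while the trunk still receives a nonzero one because it feeds `head_a`:

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    # Shared trunk, then two independent linear heads (no biases, for brevity).
    h = jax.nn.relu(x @ params["trunk"])
    return h @ params["head_a"], h @ params["head_b"]

def loss_a_only(params, x, y_a):
    out_a, _ = forward(params, x)  # out_b is computed but never used
    return jnp.mean((out_a - y_a) ** 2)

k_trunk, k_a, k_b, k_x = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "trunk": jax.random.normal(k_trunk, (3, 4)),
    "head_a": jax.random.normal(k_a, (4, 1)),
    "head_b": jax.random.normal(k_b, (4, 1)),
}
x = jax.random.normal(k_x, (8, 3))
y_a = jnp.ones((8, 1))

g = jax.grad(loss_a_only)(params, x, y_a)
print(jnp.abs(g["head_b"]).max())  # exactly zero: head_b does not affect loss_a
print(jnp.abs(g["trunk"]).max())   # nonzero: the trunk feeds head_a
```

Adding `loss_b` to the objective is what gives `head_b` a gradient; the trunk then sums contributions from both tasks.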
The model
```python
import jax
from flax import nnx

class TwoHead(nnx.Module):
    def __init__(self, d_in, hidden, rngs):
        # Shared trunk followed by two task-specific heads.
        self.trunk = nnx.Linear(d_in, hidden, rngs=rngs)
        self.head_a = nnx.Linear(hidden, 1, rngs=rngs)
        self.head_b = nnx.Linear(hidden, 1, rngs=rngs)

    def __call__(self, x):
        h = jax.nn.relu(self.trunk(x))  # features seen by both heads
        return self.head_a(h), self.head_b(h)
```
A model that returns a tuple is fine in NNX. nnx.value_and_grad only requires the function being differentiated (the LOSS) to return a scalar; the model's own output can be any pytree of arrays.
The loss
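```python
import jax.numpy as jnp

def loss_fn(model, x, y_a, y_b):
    # alpha is a fixed float captured from the enclosing scope
    out_a, out_b = model(x)
    loss_a = jnp.mean((out_a - y_a) ** 2)  # MSE for task A
    loss_b = jnp.mean((out_b - y_b) ** 2)  # MSE for task B
    return alpha * loss_a + (1.0 - alpha) * loss_b
```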
Notice that loss_a and loss_b are computed inside loss_fn, so both are part of the computation being differentiated. After nnx.value_and_grad, the gradient on the trunk is α * grads_from_loss_a + (1 - α) * grads_from_loss_b: differentiation is linear, so the weighting carries through to the gradients automatically.
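The linearity argument can be checked numerically. This sketch uses plain `jax.grad` on a hypothetical parameter dict instead of NNX modules, and compares the gradient of the combined loss with α·∇loss_a + (1 − α)·∇loss_b, leaf by leaf:

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    h = jax.nn.relu(x @ params["trunk"])
    return h @ params["head_a"], h @ params["head_b"]

def mse(pred, y):
    return jnp.mean((pred - y) ** 2)

def loss_a(params, x, y_a, y_b):
    return mse(forward(params, x)[0], y_a)

def loss_b(params, x, y_a, y_b):
    return mse(forward(params, x)[1], y_b)

def combined(params, x, y_a, y_b, alpha):
    return alpha * loss_a(params, x, y_a, y_b) + (1.0 - alpha) * loss_b(params, x, y_a, y_b)

keys = jax.random.split(jax.random.PRNGKey(0), 6)
params = {
    "trunk": jax.random.normal(keys[0], (3, 4)),
    "head_a": jax.random.normal(keys[1], (4, 1)),
    "head_b": jax.random.normal(keys[2], (4, 1)),
}
x = jax.random.normal(keys[3], (8, 3))
y_a = jax.random.normal(keys[4], (8, 1))
y_b = jax.random.normal(keys[5], (8, 1))
alpha = 0.3

g_comb = jax.grad(combined)(params, x, y_a, y_b, alpha)   # gradient of the weighted sum
g_a = jax.grad(loss_a)(params, x, y_a, y_b)
g_b = jax.grad(loss_b)(params, x, y_a, y_b)
g_mix = jax.tree.map(lambda ga, gb: alpha * ga + (1.0 - alpha) * gb, g_a, g_b)

# Same gradients up to float32 rounding, for the trunk and both heads.
same = all(bool(jnp.allclose(g_comb[k], g_mix[k], rtol=1e-5, atol=1e-6)) for k in g_comb)
print(same)  # True
```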
Choosing α
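The simplest choice is a fixed scalar, such as α = 0.5 for equal weighting. In practice, common alternatives include:
- Heuristics: weight each task by how “important” it is.
- Loss balancing: divide each loss by (an estimate of) its current magnitude, treated as a constant, so that no task dominates just because its loss has a larger scale.
- Learned weights: parameterize the per-task weights and learn them jointly with the model (the “uncertainty weighting” of Kendall et al., which learns one uncertainty term per task).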
For this problem α is given as a fixed input.
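The three schemes can be sketched as plain functions of the two scalar losses. The function names are hypothetical; `magnitude_balanced` is one simple balancing heuristic (not a specific published method), and `uncertainty_weighted` uses one commonly used regression form of Kendall et al.'s weighting, with a learned log-variance `s_i` per task that would be an extra trainable parameter in a real model:

```python
import jax
import jax.numpy as jnp

def fixed_weight(loss_a, loss_b, alpha=0.5):
    # The scheme used in this problem: a constant convex combination.
    return alpha * loss_a + (1.0 - alpha) * loss_b

def magnitude_balanced(loss_a, loss_b):
    # Divide each loss by its own current value, frozen with stop_gradient.
    # Each term then has value 1, and its gradient is grad(loss_i) / loss_i.
    return (loss_a / jax.lax.stop_gradient(loss_a)
            + loss_b / jax.lax.stop_gradient(loss_b))

def uncertainty_weighted(loss_a, loss_b, log_var_a, log_var_b):
    # sum_i 0.5 * exp(-s_i) * loss_i + 0.5 * s_i, with s_i = log(sigma_i ** 2).
    # A large s_i down-weights task i, and the 0.5 * s_i term keeps s_i from growing unboundedly.
    return (0.5 * jnp.exp(-log_var_a) * loss_a + 0.5 * log_var_a
            + 0.5 * jnp.exp(-log_var_b) * loss_b + 0.5 * log_var_b)

print(fixed_weight(2.0, 0.5))                    # 1.25
print(uncertainty_weighted(2.0, 0.5, 0.0, 0.0))  # 1.25 (s = 0 reduces to equal 0.5 weights)
ga, gb = jax.grad(magnitude_balanced, argnums=(0, 1))(2.0, 0.5)
print(ga, gb)  # 0.5 2.0 (the smaller loss gets the larger multiplier)
```

With fixed weights the multiplier on each task's gradient never changes; the other two schemes adapt it during training, at the cost of extra moving parts.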
Common pitfalls
- Returning loss_a and loss_b separately and summing OUTSIDE loss_fn. value_and_grad differentiates only what loss_fn returns: returning a tuple without has_aux=True raises an error, and returning just one of the losses means the gradient ignores the other task. Combine the losses inside the function being differentiated.
- Not casting α to float. If it arrives as a jnp.float32 scalar, 1.0 - α still works, but casting once with float(alpha) at the function boundary keeps α a plain Python constant inside loss_fn.
- Using one head for both tasks. That defeats the entire point: the trunk would learn a single representation tied to a single output, with no task specialization.
Problem
Implement multi_task_loss(seed, x, y_a, y_b, alpha):
- Define TwoHead(nnx.Module) with trunk: nnx.Linear(D, 8), head_a: nnx.Linear(8, 1), head_b: nnx.Linear(8, 1). __call__(self, x) returns (out_a, out_b) after a ReLU on the trunk.
- Build it with nnx.Rngs(int(seed)) and wrap it in nnx.Optimizer(model, optax.sgd(0.05), wrt=nnx.Param).
- Define loss_fn(model, x, y_a, y_b) that returns α * MSE(out_a, y_a) + (1 - α) * MSE(out_b, y_b).
- Take ONE optimizer step.
- Recompute loss_a_final, loss_b_final, and the combined value at the new params.
- Return jnp.array([combined, loss_a_final, loss_b_final]) as a 1-D array of shape (3,).
Inputs:
- seed: float (cast to int).
- x: 2-D, shape (N, D).
- y_a, y_b: each 2-D, shape (N, 1).
- alpha: float in [0, 1].
Output: a 1-D array of shape (3,), [combined_loss, loss_a, loss_b], evaluated after one step.