🏋️

Uncertainty Weightingとは

に公開

元論文: Kendall, A., Gal, Y., & Cipolla, R. (2018). Multi-task learning using uncertainty to weigh losses for scene geometry and semantics. CVPR 2018.

機械学習モデルに複数のタスクを同時に学習させたいとき、どうやってバランスを取るべきでしょうか?この問題に対する対策として、「Uncertainty Weighting」があります

マルチタスク学習の課題

深層学習モデルを訓練する際、複数の目的を同時に達成したいケースは珍しくありません。例えば、自動運転システムでは物体検出と深度推定を同時に行いたいですし、画像処理では分類とセグメンテーションを同時に実行したい場合があります。
このとき直面する最大の課題は、各タスクの損失関数をどのような重みで組み合わせるかという問題です。単純に足し合わせると、スケールの大きいタスクが支配的になってしまい、うまく学習が進みません。

Uncertainty Weightingの基本アイデア

Uncertainty Weightingは、2018年にKendallらによって提案された手法で、各タスクの「不確実性」を考慮して重みを自動的に調整します。ここでの重要なことは、タスクごとの予測の不確実性を明示的にモデル化し、それを重み付けに活用するという点です。

理論的背景

まず、マルチタスク学習の損失関数を確率的な観点から定式化してみます。複数のタスクがある場合、全体の尤度は各タスクの尤度の積として表現できます:

p(y_1, y_2, \dots, y_K | \boldsymbol{x}, \boldsymbol{W}) = \prod_{i=1}^{K} p(y_i | \boldsymbol{x}, \boldsymbol{W})

ここで、y_iは各タスクの出力、\boldsymbol{x}は入力、\boldsymbol{W}はモデルのパラメータです。

各タスクの出力にガウシアンノイズを仮定すると:

p(y_i | \boldsymbol{f}^{\boldsymbol{W}}(\boldsymbol{x})) = \mathcal{N}(\boldsymbol{f}^{\boldsymbol{W}}(\boldsymbol{x}), \sigma_i^2)

ここで\sigma_i^2が各タスクの不確実性を表すパラメータとなります。

損失関数の導出

負の対数尤度を最小化することを考えると、最終的な損失関数は次のような形になります:

\mathcal{L} = \sum_{i=1}^{K} \left( \frac{1}{2\sigma_i^2} \mathcal{L}_i + \log \sigma_i \right)

ここで注目すべきポイントは:

  • \mathcal{L}_i : 各タスクの通常の損失関数
  • \frac{1}{2\sigma_i^2} : タスクiの重み(不確実性が大きいほど重みが小さくなる)
  • \log \sigma_i : 正則化項(\sigma_iが無限大になることを防ぐ)

直感的な理解

この手法の素晴らしい点は、不確実性の高いタスクには小さな重みを、確実性の高いタスクには大きな重みを自動的に割り当てることです。
例えば、画像から「物体の位置」と「物体までの距離」を同時に予測するケースを考えてみましょう。距離推定は本質的に難しいタスクなので、モデルは自然とこのタスクに対して大きな
\sigma(不確実性)を学習します。結果として、距離推定の損失が全体の学習を支配することを防げるのです。

実装例

PyTorchでの実装例は以下になります。
実装では一部理論式と異なる。これは\sigma^2が0に近づくと計算が不安定になることを防ぐためで、代わりに\log(\sigma^2)を学習可能なパラメータとして扱うのがテクニックらしい。

import torch
import torch.nn as nn

class UncertaintyWeightedLoss(nn.Module):
    def __init__(self, num_tasks):
        super().__init__()
        # log(σ^2)を学習可能パラメータとして定義
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))
    
    def forward(self, losses):
        """
        losses: 各タスクの損失のリスト
        """
        weighted_losses = []
        for i, loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])
            weighted_loss = precision * loss + self.log_vars[i]
            weighted_losses.append(weighted_loss)
        
        return sum(weighted_losses)

使用例

# モデルとタスク固有の損失関数を定義
model = YourMultiTaskModel()
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.CrossEntropyLoss()
uncertainty_loss = UncertaintyWeightedLoss(num_tasks=2)

# 訓練ループ
output1, output2 = model(input_data)
loss1 = criterion_task1(output1, target1)
loss2 = criterion_task2(output2, target2)

# Uncertainty Weightingを適用
total_loss = uncertainty_loss([loss1, loss2])
total_loss.backward()

実験的な知見とコツ

初期値の重要性

log_varsの初期値は0(つまり\sigma^2 = 1)から始めるのが一般的だが、タスクの特性によっては調整が必要な場合がある。

学習率の調整

不確実性パラメータの学習率は、モデル本体のパラメータとは別に設定することで、より安定した学習が可能になることがある。

モニタリング

学習中の各タスクの重み(\sigma^2)をモニタリングすることで、モデルがどのタスクを「難しい」と判断しているかを把握できる。これは、モデルの振る舞いを理解する上で非常に有用な情報である。

Discussion