Uncertainty Weightingとは
元論文: Kendall, A., Gal, Y., & Cipolla, R. (2018). Multi-task learning using uncertainty to weigh losses for scene geometry and semantics. CVPR 2018.
機械学習モデルに複数のタスクを同時に学習させたいとき、どうやってバランスを取るべきでしょうか?この問題に対する対策として、「Uncertainty Weighting」があります
マルチタスク学習の課題
深層学習モデルを訓練する際、複数の目的を同時に達成したいケースは珍しくありません。例えば、自動運転システムでは物体検出と深度推定を同時に行いたいですし、画像処理では分類とセグメンテーションを同時に実行したい場合があります。
このとき直面する最大の課題は、各タスクの損失関数をどのような重みで組み合わせるかという問題です。単純に足し合わせると、スケールの大きいタスクが支配的になってしまい、うまく学習が進みません。
Uncertainty Weightingの基本アイデア
Uncertainty Weightingは、2018年にKendallらによって提案された手法で、各タスクの「不確実性」を考慮して重みを自動的に調整します。ここでの重要なことは、タスクごとの予測の不確実性を明示的にモデル化し、それを重み付けに活用するという点です。
理論的背景
まず、マルチタスク学習の損失関数を確率的な観点から定式化してみます。複数のタスクがある場合、全体の尤度は各タスクの尤度の積として表現できます:
ここで、
各タスクの出力にガウシアンノイズを仮定すると:
ここで
損失関数の導出
負の対数尤度を最小化することを考えると、最終的な損失関数は次のような形になります:
ここで注目すべきポイントは:
-
: 各タスクの通常の損失関数\mathcal{L}_i -
: タスクiの重み(不確実性が大きいほど重みが小さくなる)\frac{1}{2\sigma_i^2} -
: 正則化項(\log \sigma_i が無限大になることを防ぐ)\sigma_i
直感的な理解
この手法の素晴らしい点は、不確実性の高いタスクには小さな重みを、確実性の高いタスクには大きな重みを自動的に割り当てることです。
例えば、画像から「物体の位置」と「物体までの距離」を同時に予測するケースを考えてみましょう。距離推定は本質的に難しいタスクなので、モデルは自然とこのタスクに対して大きな
実装例
PyTorchでの実装例は以下になります。
実装では一部理論式と異なる。これは
import torch
import torch.nn as nn
class UncertaintyWeightedLoss(nn.Module):
def __init__(self, num_tasks):
super().__init__()
# log(σ^2)を学習可能パラメータとして定義
self.log_vars = nn.Parameter(torch.zeros(num_tasks))
def forward(self, losses):
"""
losses: 各タスクの損失のリスト
"""
weighted_losses = []
for i, loss in enumerate(losses):
precision = torch.exp(-self.log_vars[i])
weighted_loss = precision * loss + self.log_vars[i]
weighted_losses.append(weighted_loss)
return sum(weighted_losses)
使用例
# モデルとタスク固有の損失関数を定義
model = YourMultiTaskModel()
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.CrossEntropyLoss()
uncertainty_loss = UncertaintyWeightedLoss(num_tasks=2)
# 訓練ループ
output1, output2 = model(input_data)
loss1 = criterion_task1(output1, target1)
loss2 = criterion_task2(output2, target2)
# Uncertainty Weightingを適用
total_loss = uncertainty_loss([loss1, loss2])
total_loss.backward()
実験的な知見とコツ
初期値の重要性
log_varsの初期値は0(つまり
学習率の調整
不確実性パラメータの学習率は、モデル本体のパラメータとは別に設定することで、より安定した学習が可能になることがある。
モニタリング
学習中の各タスクの重み(
Discussion