🦔

【Focal Loss】クラス不均衡に対応する効果的な損失関数

2025/03/05に公開

はじめに

Focal Loss[1]とは、2017年にFacebook AI Research (FAIR) によって提案された損失関数で ICCV 2017 に採択されています。主に物体検出タスクにおけるクラス不均衡問題を解決するために設計されました。本記事では、なぜFocal Lossが必要とされたのか、その理論的背景と導出、そして実際の応用について解説します。

なぜFocal Lossを調べたか

SAM(Segment Anything Model)論文を読んでいる際に、SAMの損失関数がFocal LossとDice Lossの組み合わせであることを知りました[2]

Losses and training. We supervise mask prediction with the linear combination of focal loss [65] and dice loss [73] used in [14].

Dice Lossについては理解していましたが、Focal Lossについては知らなかったため、詳しく調査しました。

背景:物体検出手法

Two-Stage Object Detection

物体検出の分野では、従来から高精度なモデルの多くがR-CNN(Regions with CNN features)をベースとしたTwo-Stage Object Detectorのアーキテクチャを採用しています。

R-CNNは2014年に発表された論文で、Two-Stage Detection Algorithmを提唱した先駆的な研究です[3]。この手法は以下の3つのステージで構成されています:

  1. Region Proposals生成: モデルは画像中の物体候補領域を提案します
  2. 特徴抽出: 各候補領域からCNN特徴量を抽出します
  3. 分類: 抽出された特徴量をFully Connected Layerに通し、SVMを用いて各候補がどのカテゴリに属するかを分類します

R-CNN

当時は、このTwo-Stage Object Detectionアーキテクチャを持つモデルがSOTAの性能を達成していました。しかし、この手法には「学習・推論コストと時間が高い」というデメリットがありました。

One-Stage Object Detection

一方、YOLO(下記図[4])やSSDなどのOne-Stage Object Detectorは、高速に物体検出を行うネットワークとして提案されていました。しかし、当時はTwo-Stage Object Detectorと比較して精度が劣るという課題がありました。

本論文の著者らは、「One-Stage Object DetectorがTwo-Stage Object Detectorと同等の精度を出せない主な原因は背景・前景クラス間の不均衡にある」という仮説を立てました。

クラス不均衡問題

一般的な物体検出タスクでは、画像のほとんどのピクセルは背景(background)に属します。検出対象となる前景(foreground)に属するピクセル数は、背景と比較すると圧倒的に少ないのです。つまり、1枚の画像内でforegroundとbackgroundの比率は極めて不均衡になっています。

(Two-Stage Object Detectorは、Region Proposalとして候補となったオブジェクトのみを処理することで、この不均衡を解消しています。)

損失関数の進化:Cross EntropyからFocal Lossへ

Cross Entropyを出発点とし、上記の問題を解決するFocal Lossを導出します。

Cross Entropy (CE)

まず、通常のCross Entropyについて定式化します。バイナリ分類の場合

\text{Binary Cross Entropy}(p, y) = \begin{cases} -\log(p) & \text{if } y = 1 \\ -\log(1-p) & \text{otherwise} \end{cases}

ここで、y \in \{0, 1\}は真のラベル、p \in [0, 1]はモデルが予測した確率値です。

これを簡略化するために、p_tを以下のように定義します

p_t = \begin{cases} p & \text{if } y=1 \\ 1-p & \text{otherwise} \end{cases}

すると、Cross Entropyは次のように書き換えられます

\text{CE}(p, y) = \text{CE}(p_t) = -\log(p_t)

マルチクラス分類の場合は

\text{Multiclass Cross Entropy}(p_t) = -\sum_i p_i \cdot \log(q_i)

ここで
p_i:真の確率分布(例:[1, 0, 0])
q_i:予測された確率分布(例:[0.25, 0.67, 0.08])

クラス不均衡問題とCEの限界

Cross Entropyでは、高確率(0.6-1.0)で正しく分類された「簡単な例」でも、損失値は0.4-0.6程度と小さくはありません。そのため、大量のbackground(簡単な負例)の損失が蓄積されると、少数のforeground(難しい正例)の損失よりも支配的になってしまいます。

これにより以下のような問題が生じます

  1. 簡単な負例の過剰な影響: 多数の簡単な負例(背景)の寄与が支配的になり、難しい例や正例からの勾配が相対的に小さくなる
  2. 損失関数の問題: モデルが正例を適切に分類できていなくても、全体の損失関数の値は小さくなり、学習が適切に進まなくなる

Balanced Cross Entropy

上記の問題に対処するために、Balanced Cross Entropyが考案されました

\text{Balanced CE}(p_t) = -\alpha_t\log(p_t)

ここで、\alpha \in [0, 1]はクラスの重みパラメータで、クラス1(y = 1)には\alphaを、それ以外(y = -1)には1 - \alphaを割り当てます。このパラメータは通常、各クラスの出現頻度や交差検証によって経験的に決定されます。

Balanced Cross Entropyの限界

クラス不均衡に対する一般的なアプローチとして、クラスバランスに従って異なる重み(\alpha)を付ける方法がありますが、これは単に正例と負例の数の不均衡を調整するだけで、「分類が簡単な例」と「分類が難しい例」を区別しません。

著者らが注目した真の課題は、単なるクラス不均衡の解消ではなく、background(分類が簡単な例)の損失値が膨大となり、foreground(分類が難しい例)の損失値の寄与が小さくなるために、検出したい物体に対する学習が進まない現象です。これは従来のBalanced Cross Entropyでは解決できない問題です。

Focal Loss

Focal Lossは、分類が簡単な例の損失値を小さく、困難な例の損失値を大きくするように設計された損失関数です。形式的には、Cross Entropy Lossに調整係数(1 - p_t)^\gammaを追加したものです:

\text{Focal Loss}(p_t) = -(1 - p_t)^\gamma \log(p_t)

ここで \gamma ≧ 0 はパラメータで、簡単な例の損失をどの程度減衰させるかを決定します。簡単に分類に成功している例では(1 - p_t)^\gammaが小さい値になるため、損失への寄与が小さくなります。

下記図は\gammaを段階的に変化させたときの損失関数の形です。well-classified examplesの部分が「分類が簡単な例」の損失関数の値です。\gammaが大きくなるにつれて well-classified examplesのprobabilityが小さくなっていることがわかります。

Focal Lossの効果

Focal Lossの論文では RetinaNet というOne-Stage Object detectionモデルを提案し、 Focal Loss を用いて学習させています。

精度を既存の One-Stage object detector と比較したのが次の図です。当時のTwo-Stage modelと比較して同等かそれ以上の性能を発揮しています。

論文の実験では\gamma = 2が最適値として採用されています。

まとめ

Focal Lossは、物体検出タスクにおけるクラス不均衡問題を効果的に解決するために設計された損失関数です。Cross Entropyを基に、簡単な例の損失寄与を減らし、難しい例に焦点を当てることで学習効率を高めます。

SAMは小さいオブジェクトも学習データに採用しており、まさにFocal Lossは"Segment Anything"の目的に即した損失関数であることがわかりました。実際、SAMはzero-shotの推論においても小さなオブジェクトのセグメンテーションに成功しています(SAM論文 - Figure2)[2]。

また、この損失関数は物体検出だけでなく、クラス不均衡を含む様々な機械学習タスクにも応用可能で、一般的な分類タスクにも有効な、汎用性の高い損失関数であることが伺えます。

参考文献

  1. Focal Loss for Dense Object Detection
  2. Segment Anything
  3. Rich feature hierarchies for accurate object detection and semantic segmentation
  4. You Only Look Once: Unified, Real-Time Object Detection
  5. Qiitaの解説記事

Discussion