【Focal Loss】クラス不均衡に対応する効果的な損失関数
はじめに
Focal Loss[1]とは、2017年にFacebook AI Research (FAIR) によって提案された損失関数で ICCV 2017 に採択されています。主に物体検出タスクにおけるクラス不均衡問題を解決するために設計されました。本記事では、なぜFocal Lossが必要とされたのか、その理論的背景と導出、そして実際の応用について解説します。
なぜFocal Lossを調べたか
SAM(Segment Anything Model)論文を読んでいる際に、SAMの損失関数がFocal LossとDice Lossの組み合わせであることを知りました[2]
Losses and training. We supervise mask prediction with the linear combination of focal loss [65] and dice loss [73] used in [14].
Dice Lossについては理解していましたが、Focal Lossについては知らなかったため、詳しく調査しました。
背景:物体検出手法
Two-Stage Object Detection
物体検出の分野では、従来から高精度なモデルの多くがR-CNN(Regions with CNN features)をベースとしたTwo-Stage Object Detectorのアーキテクチャを採用しています。
R-CNNは2014年に発表された論文で、Two-Stage Detection Algorithmを提唱した先駆的な研究です[3]。この手法は以下の3つのステージで構成されています:
- Region Proposals生成: モデルは画像中の物体候補領域を提案します
- 特徴抽出: 各候補領域からCNN特徴量を抽出します
- 分類: 抽出された特徴量をFully Connected Layerに通し、SVMを用いて各候補がどのカテゴリに属するかを分類します
当時は、このTwo-Stage Object Detectionアーキテクチャを持つモデルがSOTAの性能を達成していました。しかし、この手法には「学習・推論コストと時間が高い」というデメリットがありました。
One-Stage Object Detection
一方、YOLO(下記図[4])やSSDなどのOne-Stage Object Detectorは、高速に物体検出を行うネットワークとして提案されていました。しかし、当時はTwo-Stage Object Detectorと比較して精度が劣るという課題がありました。
本論文の著者らは、「One-Stage Object DetectorがTwo-Stage Object Detectorと同等の精度を出せない主な原因は背景・前景クラス間の不均衡にある」という仮説を立てました。
クラス不均衡問題
一般的な物体検出タスクでは、画像のほとんどのピクセルは背景(background)に属します。検出対象となる前景(foreground)に属するピクセル数は、背景と比較すると圧倒的に少ないのです。つまり、1枚の画像内でforegroundとbackgroundの比率は極めて不均衡になっています。
(Two-Stage Object Detectorは、Region Proposalとして候補となったオブジェクトのみを処理することで、この不均衡を解消しています。)
損失関数の進化:Cross EntropyからFocal Lossへ
Cross Entropyを出発点とし、上記の問題を解決するFocal Lossを導出します。
Cross Entropy (CE)
まず、通常のCross Entropyについて定式化します。バイナリ分類の場合
ここで、
これを簡略化するために、
すると、Cross Entropyは次のように書き換えられます
マルチクラス分類の場合は
ここで
クラス不均衡問題とCEの限界
Cross Entropyでは、高確率(0.6-1.0)で正しく分類された「簡単な例」でも、損失値は0.4-0.6程度と小さくはありません。そのため、大量のbackground(簡単な負例)の損失が蓄積されると、少数のforeground(難しい正例)の損失よりも支配的になってしまいます。
これにより以下のような問題が生じます
- 簡単な負例の過剰な影響: 多数の簡単な負例(背景)の寄与が支配的になり、難しい例や正例からの勾配が相対的に小さくなる
- 損失関数の問題: モデルが正例を適切に分類できていなくても、全体の損失関数の値は小さくなり、学習が適切に進まなくなる
Balanced Cross Entropy
上記の問題に対処するために、Balanced Cross Entropyが考案されました
ここで、
Balanced Cross Entropyの限界
クラス不均衡に対する一般的なアプローチとして、クラスバランスに従って異なる重み(
著者らが注目した真の課題は、単なるクラス不均衡の解消ではなく、background(分類が簡単な例)の損失値が膨大となり、foreground(分類が難しい例)の損失値の寄与が小さくなるために、検出したい物体に対する学習が進まない現象です。これは従来のBalanced Cross Entropyでは解決できない問題です。
Focal Loss
Focal Lossは、分類が簡単な例の損失値を小さく、困難な例の損失値を大きくするように設計された損失関数です。形式的には、Cross Entropy Lossに調整係数
ここで
下記図は
Focal Lossの効果
Focal Lossの論文では RetinaNet というOne-Stage Object detectionモデルを提案し、 Focal Loss を用いて学習させています。
精度を既存の One-Stage object detector と比較したのが次の図です。当時のTwo-Stage modelと比較して同等かそれ以上の性能を発揮しています。
論文の実験では
まとめ
Focal Lossは、物体検出タスクにおけるクラス不均衡問題を効果的に解決するために設計された損失関数です。Cross Entropyを基に、簡単な例の損失寄与を減らし、難しい例に焦点を当てることで学習効率を高めます。
SAMは小さいオブジェクトも学習データに採用しており、まさにFocal Lossは"Segment Anything"の目的に即した損失関数であることがわかりました。実際、SAMはzero-shotの推論においても小さなオブジェクトのセグメンテーションに成功しています(SAM論文 - Figure2)[2]。
また、この損失関数は物体検出だけでなく、クラス不均衡を含む様々な機械学習タスクにも応用可能で、一般的な分類タスクにも有効な、汎用性の高い損失関数であることが伺えます。
Discussion