
【ML】Focal Loss explained

Published 2024/04/25

1. Focal Loss

Focal Loss is a loss function designed to address class imbalance in machine learning, particularly in object detection and image classification tasks.
It extends the traditional cross-entropy loss by focusing on hard examples (those that are more difficult for the model to classify) while reducing the contribution of easy examples.

In this way, it helps models learn from challenging samples, which can lead to improved performance on datasets with imbalanced class distributions.

2. How It Works

Focal Loss introduces a tunable parameter, $\gamma$, which controls the degree to which hard examples are focused upon.
When $\gamma = 0$, the focal loss is equivalent to the standard cross-entropy loss. As $\gamma$ increases, the focus shifts toward hard examples.
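As a quick numeric preview (the full formula appears in the next section), the down-weighting factor $(1 - p_t)^{\gamma}$ applied to each example shrinks far faster for well-classified examples than for hard ones as $\gamma$ grows. The sample probabilities 0.9 and 0.1 below are my own illustrative choices:

```python
# Down-weighting factor (1 - p_t)^gamma for an easy example (p_t = 0.9)
# and a hard example (p_t = 0.1). p_t is the model's predicted
# probability of the correct class, defined formally in the next section.
for gamma in [0, 1, 2, 5]:
    easy = (1 - 0.9) ** gamma
    hard = (1 - 0.1) ** gamma
    print(f"gamma={gamma}: easy weight={easy:.5f}, hard weight={hard:.3f}")
```

At $\gamma = 0$ both weights are 1 (plain cross-entropy); at $\gamma = 2$ the easy example is weighted 0.01 versus 0.81 for the hard one, so the hard example dominates the loss.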

3. Formula

The formulas for Focal Loss (FL) and the traditional binary cross-entropy (BCE) are given below:
・Traditional cross-entropy (BCE)
$$CE(p_t) = -\log(p_t)$$
・Focal Loss
$$FL(p_t) = -(1 - p_t)^{\gamma}\log(p_t)$$

Here, $p_t$ represents the probability of the correct class, as predicted by the model. The term $(1 - p_t)^{\gamma}$ down-weights examples that are correctly classified (easy examples), allowing the model to concentrate more on misclassified samples.
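To make the difference concrete, here is a minimal NumPy sketch of both formulas. The function names and the sample values $p_t = 0.9$ (easy) and $p_t = 0.1$ (hard) are my own choices for illustration:

```python
import numpy as np

def cross_entropy(p_t):
    # CE(p_t) = -log(p_t)
    return -np.log(p_t)

def focal_loss(p_t, gamma=2.0):
    # FL(p_t) = -(1 - p_t)^gamma * log(p_t)
    return -((1 - p_t) ** gamma) * np.log(p_t)

p_t = np.array([0.9, 0.1])  # easy example vs. hard example
print(cross_entropy(p_t))             # [0.105 2.303]
print(focal_loss(p_t, gamma=2.0))     # [0.00105 1.865]
```

With $\gamma = 2$, the easy example's loss drops by a factor of 100 (0.105 → 0.00105), while the hard example's loss barely changes (2.303 → 1.865), so the hard example dominates the gradient.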

Plotting the loss against $p_t$ for several values of $\gamma$ (as sketched below) shows that the loss curve flattens as $\gamma$ increases. This suppresses strong learning signals from simple samples that the model already predicts well, such as those with $p_t = 0.6$.
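These curves are easy to reproduce. Below is a minimal matplotlib sketch; the particular $\gamma$ values plotted are illustrative:

```python
import numpy as np
import matplotlib.pyplot as plt

p_t = np.linspace(0.01, 1.0, 200)  # probability of the correct class
for gamma in [0, 0.5, 1, 2, 5]:
    loss = -((1 - p_t) ** gamma) * np.log(p_t)
    plt.plot(p_t, loss, label=f"gamma = {gamma}")

plt.xlabel("p_t (probability of the correct class)")
plt.ylabel("loss")
plt.legend()
plt.show()
```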

4. When to Use

・Class Imbalance
When the dataset has a significant imbalance between classes, Focal Loss can be useful (a minimal usage sketch follows this list).
・Object Detection
It is commonly used in tasks like object detection, where certain classes might be underrepresented.
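For example, a binary focal loss can be dropped in as a replacement for BCE in a PyTorch training loop. This is a minimal sketch under my own assumptions (binary classification from logits, the class name `BinaryFocalLoss`, and $\gamma = 2$ are illustrative choices, not a reference implementation):

```python
import torch
import torch.nn.functional as F

class BinaryFocalLoss(torch.nn.Module):
    """Binary focal loss; with gamma=0 it reduces to standard BCE."""

    def __init__(self, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Unreduced BCE gives -log(p_t) per example ...
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = torch.exp(-bce)  # ... so p_t can be recovered as exp(-bce)
        return ((1.0 - p_t) ** self.gamma * bce).mean()

# Usage: swaps in for nn.BCEWithLogitsLoss() in an existing training loop.
criterion = BinaryFocalLoss(gamma=2.0)
logits = torch.randn(8, requires_grad=True)  # dummy model outputs
targets = torch.randint(0, 2, (8,)).float()  # dummy 0/1 labels
loss = criterion(logits, targets)
loss.backward()
```

Recovering $p_t$ as `exp(-bce)` avoids computing the probabilities separately, since the unreduced BCE is exactly $-\log(p_t)$.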

