🤖

A survey of loss functions for semantic segmentation

2023/09/24に公開

Introduction

セグメンテーションについて

  • semantic segmentation
    • クラス内の個体識別は行わない
    • 通常画像内のすべての物体に対してピクセルレベルで分類を実施する
  • instance segmentation
    • クラス内の個体識別を行う
    • 画像内の特定のクラスのみを分類する
  • papnoptic segmentation
    • semantic + instance segmentation で、画像内すべての物体を個体識別してセグメンテーションする

Binary Cross Entropy

通常二値分類のクラス分類タスクで幅広く使用されているが、セグメンテーションにおいてもピクセル単位で pos/neg を判定する用途で使用することができる。p(x) が真値、q(x) が予測値であるとすると

L = - \sum_x p(x) \log q(x) = -( y\log \hat{y} + (1-y)\log(1-\hat{y}))

と表される。

Weighted Binary Cross-Entropy

BCE では pos/neg の正誤判定を平等に扱うため、クラス内の偏り(基本的には negative が多くて、画像内の一部分だけをセグメンテーションするようなタスク)に弱いという側面がある。そこで positive sample に重み付けしたものがこの損失関数。

L = -( \beta * y\log \hat{y} + (1-y)\log(1-\hat{y}))

Balanced Cross-Entropy

Weghted Binary Cross-Entropy に加えて negative 側も重み付けしたもの。

L = -( \beta * y\log \hat{y} + (1-\beta) * (1-y)\log(1-\hat{y}))

Focal Loss

Binary Cross Entropy の亜種であ

Binary Cross Entropy では [0, 1] を予測するタスクであるということを思い出して、y=1 のときには -\log \hat{y}y=0 のときには -\log(1-\hat{y}) であるために、y=1 のときに \hat{y}y=0のときに 1-\hat{y} をとる p_t という変数を作ると

CE = -log p_t

と簡略化することができる。Focal Loss はこれに対して重み付けを行ったものであり、次のような式である

L = - \alpha_t(1-p_t)^\gamma \log p_t

Dice Loss

2つの画像の類似度を図る指標として Dice coefficient が広く用いられており、それを損失関数に取り入れたもの。Dice coeff. は以下の式で定義される:

D = \frac{2|A \cap B|}{|A| + |B|} = \frac{2 y\hat{y}}{ y + \hat{y}}

Dice coeff. は [0, 1] の範囲に分布する係数で、この係数が大きいほど2つの画像(=ピクセル集合)の類似度は高いと考えることができる。ただし片方がもう片方に内包されているときには Dice coeff. は1にはならないため、タスクによっては他の metrics の使用も検討する必要がある。

Dice Loss は以下の式で定義される:

L = 1 - \frac{2 y\hat{y} + 1}{ y + \hat{y} + 1}

分母分子の「1」は y=\hat{y}=0 の場合にゼロ除算にならないように付け足しているもの。

Tversky Loss

Dice coeff. を一般化したものとして Tversky index というものがあり、以下で定義されている:

TI = \frac{y\hat{y}}{y\hat{y} + \beta(1-y)\hat{y} + (1-\beta)y(1-\hat{y})}

\beta=1/2 のときには Dice coeff. と一致する。Tversky Loss は Dice coeff. と同様にゼロ除算に注意して

L = 1 - \frac{1 + y\hat{y}}{1 + y\hat{y} + \beta(1-y)\hat{y} + (1-\beta)y(1-\hat{y})}

で定義される。

Sensitivity Specificity Loss

Dice coeff. と同様に幅広くセグメンテーションタスクで使用されるメトリクスに sensiticity (=recall) と specificity がある。メトリクスを直接損失関数と使用しているもので

L = w\times sensitivity + (1-w) \times specificity

と定義されている。

結論

  • セグメンテーションタスクはなかなかに複雑なタスクなので損失関数の設計も一筋縄ではいかない
  • highly-imbalanced data であれば Focal 系、balanced data であれば Binary Cross Entropy、middly skewed data であれば dice coeff 系が初手でよさそう

参考文献

次に気になる論文

GitHubで編集を提案

Discussion