【論文紹介】Confident Learning(Northcutt et al. 2020)

7 min read読了の目安(約6600字

こんにちは、竹輪内からしです。ラベルに誤りがあるデータでの学習や、学習していないデータに対する推論に興味があって、現在勉強中です。
今回はラベルの誤りに関する論文を読みましたので、紹介します。

論文の情報

  • タイトル: Confident Learning: Estimating Uncertainty in Dataset Labels
  • 著者:
    • Curtis G. Northcutt(Massachusetts Institute of Technology, Department of EECS)
    • Lu Jiang(Google Research)
    • Isaac L. Chuang(Massachusetts Institute of Technology, Department of EECS)
  • 公開日: Fri, 4 Sep 2020 08:12:20 UTC
  • リンク:

概要

  • 訓練データセットからnoisy labelsを除去して学習することで、noisy labelsにロバストに学習する。
  • 付与されているラベルと真のラベルの同時確率を用いてnoisy labelsをモデリングする。
  • 理論的な保証があり、なおかつstate of the art。(この記事では証明については割愛します。)

noisy labelsとは何か

訓練データセットの正解ラベルが誤っているデータ(例:犬の画像にもかかわらず「猫」というラベルが付与されている)のことです。Noisy labelsが訓練データセット中に存在すると、Noisy labelsが訓練データセット中にない場合より精度が低くなってしまうことが報告されています[1]

Noisy labelsを含む訓練データセットを扱う上で、この論文は2つの課題を考えます。

  • どのようにラベルが誤っているデータを特定するか?
  • noisy labelsを含む訓練データセットでも精度よく学習するにはどうすればよいか?

uncertaintyとは何か

タイトルにもある「uncertainty」という言葉について説明します。

その前に、まず分類問題について簡単に説明します。
分類問題ではデータセットにあるデータをできるだけ正しく分けられる決定境界を学習することを目指します。イメージ図を次に示します。
aleatoric uncertainty
分類問題の概要
この例では、決定境界の左側にあるデータを"赤"クラス、右側にあるデータを"青"クラスと判定します。

この決定境界を基にデータのクラスを判定すると、予測を外してしまうことがあります。
上の図では2つの円が重なっているところでは、統計的に分類することが困難なため、正しく予測できるとは限りません。
「どれだけ予測を外しそうか」というのがuncertaintyです。

予測を外す原因は上のような統計的に分類できないことの他にもあります。
例えば、次の図のような場合を考えてみましょう。
epistemic uncertainty
他のuncertaintyの例
上の図のように、手元のデータセットにはないけれど、真には存在するようなデータがあるとしましょう。
図のような場合だと、決定境界の左側にあるデータを"赤"クラス、右側にあるデータを"青"クラスと判定してしまうと、誤ってしまうデータがあります。
このように、訓練データセットにない知らないデータだったために予測を外してしまうことがあります。

この論文が注目しているのは、訓練データセットにはあったけれど、統計的に分類することが困難なため生じたuncertainty(aleatoric uncertainty)です。

論文で解きたい問題


この論文では、真のラベルy^*と実際に付与されているラベル\tilde{y}があると考えます。
この\tilde{y}は正しいラベルかもしれませんし誤っているラベルかもしれません。
真のラベルy^*と実際に付与されているラベル\tilde{y}がclass conditional classification noise process(CNP)[2]によって確率p(\tilde{y}=i \mid y^*=j)次のように射影されると仮定します。ただしi, jは任意のクラスラベルです。

y^* \rightarrow \tilde{y}

この論文では、データ\bm{x}には依存せず、真のラベルy^*に応じてクラスを誤ると仮定しています。

もし、データ\bm{x}に付与されているラベルが\tilde{y}のときに、真のラベルがy^*である確率p(y^* \mid \tilde{y})が高ければ、ラベルが誤っている確率が高いといえます。
同様に、データ\bm{x}の真のラベルがy^*の時に、データ\bm{x}に実際についているラベルが\tilde{y}である確率p(\tilde{y} \mid y^*)が高くてもラベルが誤っている確率が高いといえます。
ただし、いずれもy^* \neq \tilde{y}です。
これらの確率を総合して考えるために、真のラベルと付与されているラベルの同時確率p(y^*,\tilde{y})によってラベルの誤りをモデル化します。
この確率p(y^*,\tilde{y})y^* = \tilde{y}の時に低ければ付与されているラベルが誤りである可能性が高いです。
論文では\bm{Q}_{\tilde{y},y^*}=p(y^*,\tilde{y})を推定して、ラベルが誤っているかを判定します。

アプローチ


Confident learningでは次の3ステップでラベルが誤っているかを判定します。

  1. Count: \bm{Q}_{\tilde{y},y^*}を推定する
  2. Rank: 判定するための指標を求める
  3. Prune: 閾値処理により判定する

次の図にConfident learningの概要を示します。
concept of confident learning
Confident learningの概要

Count


まずはCountのステップについて見ていきます。
真のラベルと付与されたラベルの同時確率\bm{Q}_{\tilde{y},y^*}は、どのようなラベル誤りが起こりやすいかを見るための確率であると考えられます。

どのようなラベル誤りが起こりやすいかを見るための方法は、同時確率を推定する他に存在します。
それは、誤りの個数を数える方法です。
例えば、「真のラベルは『1』なのに実際は『3』のラベルがつけられているデータが10個あった」などです。
この個数を用いても、同時確率と同様に、どのようなラベル誤りが起こりやすいかを調べることができます。
違いとしては、個数は正規化されていないのに対して、同時確率は[0, 1]の範囲に正規化されています。そのため、データの総数次第である個数の意味合いは変化するのに対して、同時確率はデータの総数が多くても少なくても同じ値は同じ意味を表します。

上の個数をC_{\tilde{y},y^*}と書くことにします。(上の例はC_{3,1}=10と書けます。)
では、このC_{\tilde{y},y^*}はどのように求めるのでしょうか?
論文では、\tilde{y}=iにも拘わらず、学習済みの分類器でj (i \neq j)と判定されるデータの個数を数えます。このとき、分類器がある程度自信を持ってjと予測するときにのみカウントします。

論文の式(1)の意味を整理すると

  • \bm{x} \in \bm{X}_{\tilde{y}=i}: 実際につけられているラベルがiとなるデータ
  • j= \argmax_{l \in [m]: \hat{p}(\tilde{y}=l;\bm{x},\bm{\theta})} \hat{p}(\tilde{y};\bm{x},\bm{\theta}): 分類器は、予測される確率が一番高いクラスjを予測として出力する
  • \hat{p}(\tilde{y}=j;\bm{x},\bm{\theta}) \geq t_j: 閾値t_jの確率で分類器でラベルがjと予測される

となるデータの集合\hat{\bm{X_{\tilde{y}=i, y^*=j}}}の個数を数えたものが\bm{C}_{\tilde{y}=i, y^*=j}[i][j]である
と解釈できます。

ここで、閾値の決め方が気になると思います。
閾値は、実際のラベルが\tilde{y}=jであるデータがjと予測される確率の期待値として求めます。これは次式のように書けます。

t_j = \frac{|\bm{X}_{\tilde{y}=j}|}{1} \sum_{\bm{x} \in \bm{X}_{\tilde{y}=j}} \hat{p}(\tilde{y}=j; \bm{x}, \bm{\theta})

さて、ここで求めた\bm{C}_{\tilde{y},y^*}を次の式のように正規化することで\bm{Q}_{\tilde{y},y^*}を求めることができます。

\bm{Q}_{\tilde{y},y^*} = \frac{\sum_{i \in [m],j \in [n]}(\frac{\sum_{j \in [m]}\bm{C}_{\tilde{y}=i, y^*=j}}{\bm{C}_{\tilde{y}=i, y^*=j}} \cdot |\bm{X_{\tilde{y}=i}}|)} {\frac{\sum_{j \in [m]}\bm{C}_{\tilde{y}=i, y^*=j}}{\bm{C}_{\tilde{y}=i, y^*=j}} \cdot |\bm{X_{\tilde{y}=i}}|)}

次のRank&PruneのステップではC_{\tilde{y},y^*}を使う方法と\bm{Q}_{\tilde{y},y^*}を使う方法があります。

Rank & Prune

RankステップとPruneステップにより、ラベルが誤りかを判定します。
論文ではいくつかの方法が挙げられています。

  • C_{\tilde{y},y^*}を用いる方法
    • 予測ラベルと実際のラベルが異なる場合にラベル誤りと判定する方法(CL baseline 1)
    • 推定したC_{\tilde{y},y^*}の非対角成分をもとにラベル誤りか判定する方法(CL method 2)
  • \bm{Q}_{\tilde{y},y^*}を用いる方法
    • iのラベルがついていて予測もiとなるデータの中で確率が最も低いものをラベル誤りと判定する方法(CL method 3)
    • 真のラベルの予測確率と実際についているラベルの予測確率の差が大きいものをラベル誤りと判定する方法(CL method 4)
    • 上の2つの組み合わせ(CL method 5)

実験と結果

提案手法の有効性を検証するためにCIFAR-10のデータセットを用いて実験します。
CIFAR-10は画像認識のデータセットです。

ここで、CIFAR-10にはラベルの誤りがないため、人工的にラベルを誤らせます。
今回の実験では、noise transition matrixを用いて誤ったラベルを作ります。
noise transition matrixというのは「どのクラスがどのクラスに誤りやすいか」を表した行列です。
このnoise transition matrixを作る上で変える要素が2つあります。

  • noise rate: どれくらいの確率でラベルを誤るか。値が高いほどラベルの誤りが多い。
  • sparsity: どれくらい特定のクラスとラベルを誤るか。値が高いほど少数のクラスとの間でクラスを誤る。

まずは、分類性能を見てみましょう。提案法ではラベルが誤っていると判定されたデータを除いて学習しています。
accuracy
分類性能の比較
提案手法(中央の横線より上)では、ラベル誤りが多くても精度よく学習できていることがわかります。
また、赤帯の部分のようにsparsityが大きく異なっていても、提案手法では精度さが小さく抑えられています。

続いて、どれだけ正しくラベル誤りを判定できたか見てみましょう。
detection performance
ラベル誤り判定の性能の比較
Recallの値が高いことから、誤ったラベルの見逃しが少ないことがわかります。

では、推定した\bm{Q}_{\tilde{y},y^*}はよかったのでしょうか?
estimated joint distribution
\bm{Q}_{\tilde{y},y^*}の推定性能の評価
一番左の図が真の\bm{Q}_{\tilde{y},y^*}です。つまり、データセットはこの確率でラベル誤りがあります。
推定した\bm{Q}_{\tilde{y},y^*}が中央の図です。左の図と比較すると結構似ていることがわかります。
一番右の図は真の\bm{Q}_{\tilde{y},y^*}と推定した\bm{Q}_{\tilde{y},y^*}の差を表しています。この図からもほとんど差がないことがわかります。

脚注
  1. A Closer Look at Memorization in Deep Networks ↩︎

  2. Learning From Noisy Examples ↩︎