🐶

【論文5分まとめ】Deep Mutual Learning

2021/12/22に公開

概要

通常の蒸留のような教師と生徒の間で行われる学習とは異なり、生徒同士で協力して学習する枠組みであるDeep Mutual Learningを提案している。

書誌情報

  • Zhang, Ying, et al. "Deep mutual learning." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.
  • https://arxiv.org/abs/1706.00384

ポイント

2つのネットワークの訓練

下図に示すように、通常の分類損失に加え、2つのネットワーク\Theta_{1}, \Theta_{2}の出力である確率分布のKL距離を損失に追加している。

L_{\Theta_{1}}=L_{C_{1}}+D_{K L}\left(\boldsymbol{p}_{2} \| \boldsymbol{p}_{1}\right)
L_{\Theta_{2}}=L_{C_{2}}+D_{K L}\left(\boldsymbol{p}_{1} \| \boldsymbol{p}_{2}\right)

このようなシナリオは、より多くの生徒ネットワーク\Theta_{1}, \Theta_{2}, \ldots, \Theta_{K}(K \geq 2)が存在している時にも拡張できる。しかし、K>2 の時は、各ネットワークの損失として2つの方法が考えられる。

  • 1つ目は、自分以外のネットワークの出力とのKL距離をそれぞれ出し、それらを平均する方法である。
  • 2つ目は、自分以外のネットワークの出力の平均とのKL距離を出す方法である。

それぞれの損失関数は以下のようになる。

L_{\Theta_{k}}=L_{C_{k}}+\frac{1}{K-1} \sum_{l=1, l \neq k}^{K} D_{K L}\left(\boldsymbol{p}_{l} \| \boldsymbol{p}_{k}\right)
L_{\Theta_{k}}=L_{C_{k}}+D_{K L}\left(\boldsymbol{p}_{a v g} \| \boldsymbol{p}_{k}\right), \quad \boldsymbol{p}_{a v g}=\frac{1}{K-1} \sum_{l=1, l \neq k}^{K} \boldsymbol{p}_{l} .

先に結論を述べておくと、前者の方法の方がより良い精度が得られることが、実験によって明らかになっている。

実験

実験は、CIFAR-100およびMarket-1501で行っている。以下の疑問への回答が示されている。

  • DMLの効果はどうか?
  • DMLは蒸留よりも優れているのか?
  • ネットワークの数を増やした時の効果は?
  • ネットワークの数が多い時の損失はどうすべきか?

DMLの効果

以下は、CIFAR-100での実験結果を表した表である。DMLはそれぞれのネットワークを個別に訓練した時よりもより良い精度が得られる、ということがわかる。

同様に、Market-1501での結果は以下の表のようになっている。こちらでも、DMLの導入により、精度が一貫して向上することが示されている。

蒸留とDMLの比較

一般的な蒸留と同じように、1つ目のネットワークを教師モデルとして、2つ目のネットワークを生徒モデルとして訓練している。その場合とDMLを使用して訓練したモデルと比較すると、DMLを使用した場合の方が良い精度が得られている。

ネットワーク数の影響

(a)ネットワーク数を増加させたとき、個別に訓練した場合は、その平均精度は当然横ばいになる。一方で、DMLを用いると同時に訓練するネットワークの数を増やすことで、明らかに個別に訓練した場合よりも高精度になることがわかる。

(b)複数のネットワークを個別に訓練しても、その出力を平均してアンサンブルしてあげれば、一般的に精度は向上する。DMLでもアンサンブルは有効で、個別に訓練したモデルのアンサンブルよりもDMLで訓練したモデルのアンサンブルの方が良い精度が得られている。

ネットワーク数が多いときの損失関数

ネットワークの数Kが2よりも大きいとき、先に述べたように、1. 自分以外のネットワークの出力とのKL距離をそれぞれ出してそれらを平均する方法(DML)と、2. 自分以外のネットワークの出力の平均とのKL距離を出す方法(DML_e)の2種類が考えられる。先に記したように、前者の方が高い精度を実現できることが実験的に明らかになった。後者の方法は、自分以外のネットワークのアンサンブルをターゲットとしていると言える。

そもそも蒸留は、通常のone-hotターゲットでは与えられない2位以下の確率を教師信号として与えることによって、モデルを大域最適化することができる。DMLではその効果がさらに強く与えられることで、大域最適解が得られていると考えられる。

一般に、アンサンブルしたモデルは高い精度を実現できるが、DML_eの場合はこれが裏目に出ている。DML_eがDMLに比べて悪い精度になってしまうのは、アンサンブルによって得られる確率分布が鋭いピークをもつため、DMLが本来与えられるはずの2位以下の重要な情報を低減させてしまうことが原因と考えられる。

Discussion