🎩

ArcFace の Additive Angular Margin Loss について

2021/09/10に公開

NLPer だけど Metric Learning を知る必要が出てきたので、自分の理解のためにまとめておく。


まずは普通の softmax + cross entropy loss の部分から。

x_i \in \mathbb{R}^{d}i 番目のデータを何らかのニューラルネットワークに入れて出てくるベクトル(特徴ベクトル)とする。CV 業界の場合は何次元なのかよく知らないが、例えば BERT だと d=768
通常の分類問題では、これに Linear Layer をかませて、クラス数 n の次元のベクトルに変換する。

W^T x_i + b \ (W \in \mathbb{R}^{d \times n}, b \in \mathbb{R}^{n})

b はバイアス項と呼ばれる。

この n 次元のベクトルは和が 1 になっていないので、ベクトルの各要素の値はそのまま予測確率を表したものにはならない。したがって、和が 1 になるように Softmax 関数をかける。
先程の Wj \ (1 \leq j \leq n) 列目のベクトルを W_j, bj 番目の要素を b_j と表記することにすると、Softmax 関数を使って、クラス k の予測確率は

\dfrac{\exp(W_k^T x_i + b_k)}{\sum_{j=1}^n \exp(W_j^T x_i + b_j)}

と書ける。Cross Entropy Loss では正解ラベル y_i のクラスの予測確率のみに - \log をあてはめた形になるから、

- \log \dfrac{\exp(W_{y_i}^T x_i + b_{y_i})}{\sum_{j=1}^n \exp(W_j^T x_i + b_j)}

である。
通常は、バッチによって訓練データサイズが異なっていても損失を正規化するために、バッチサイズで平均する(論文にある式はバッチサイズ N で平均されている)。


では、本題の Additive Angular Margin Loss へと進む。

まず、Linear Layer について、 W の各列ベクトル W_j のL2ノルムを 1 に, バイアス項の各要素 b_j を 0 に正規化する。
また、ネットワークから出てくるベクトル(特徴ベクトル) x_i についても L2 ノルムを一定の値 s に正規化する。

このとき、Softmax に入れる前の n 次元のベクトルの各要素の値(論文では logit と呼ばれている)は、x_iW_j のなす角を \theta_j として

W_j^T x_i + b_j = s \cos \theta_j

と書ける。

余談だが、s は、Knowledge Distillation(知識蒸留)などでお馴染みの温度付き Softmax の温度 T と関係がある(s=\frac{1}{T})。s が大きいと、一番大きな logit が Softmax を通した後により 1 に近い値になり他が 0 に近い値になることで、違いが強調されるので、Metric Learning のように分離境界をしっかり学習したい場合に適切である(論文では s=64 と置いている)。

さて、このように W_j^T x_i + b_j を角度をパラメータとする式 s \cos \theta_j に変形したことで、埋め込みは半径 s の超球面上に分布するようになる。

次に、Additive Angular Margin Loss では、正解ラベル y_i のクラスに対応する logit \theta_{y_i} だけ角度のマージン m を加える。

s \cos \theta_{y_i} \to s \cos (\theta_{y_i} + m)

これも、他のクラスとの分離境界をよりはっきり学習させる効果がある。
具体的に2クラス分類で考えると分かりやすい。マージンを加える前までは分離境界は

\begin{cases} \cos \theta_{y_1} > \cos \theta_{y_2} & (\mathrm{Class 1}) \\ \cos \theta_{y_1} < \cos \theta_{y_2} & (\mathrm{Class 2}) \end{cases}

であったのが

\begin{cases} \cos (\theta_{y_1} + m) > \cos \theta_{y_2} & (\mathrm{Class 1}) \\ \cos \theta_{y_1} < \cos (\theta_{y_2} + m) & (\mathrm{Class 2}) \end{cases}

になるということである。仮に全部の角を [0, \pi] の区間で考えると \cos は単調減少関数だから

\begin{cases} \theta_{y_1} < \theta_{y_2} & (\mathrm{Class 1}) \\ \theta_{y_1} > \theta_{y_2} & (\mathrm{Class 2}) \end{cases}

であったのが

\begin{cases} \theta_{y_1} < \theta_{y_2} - m & (\mathrm{Class 1}) \\ \theta_{y_1} > \theta_{y_2} + m & (\mathrm{Class 2}) \end{cases}

になるということである(境界面がマージン m 分それぞれ離れている)。この効果をわかりやすく図示したものが以下の論文の図になる。


なお、似た先行研究として SphereFace、CosFace があるが、
SphereFace ではマージン調整を

s \cos \theta_{y_i} \to s \cos (m \theta_{y_i})

により行なっている。
CosFace ではマージン調整を

s \cos \theta_{y_i} \to s (\cos (\theta_{y_i}) + m)

により行なっている。

これだけ見ると全くもって微々たる違いにしか見えないが、下の論文中の図に分かりやすく示されている通り、ArcFace の場合においてのみ角度の空間において線形な分離境界になっている。論文ではこの差がまるで「バタフライエフェクト」のように訓練のパフォーマンス向上に大きな影響を与えると書いてある。


余談だが、ArcFace の Arc とは(おそらく)逆三角関数 arccos のことである。実装する際に、x_iW_j の内積から \cos \theta_j を求めた後、一旦 arccos で \theta_j を取り出し、m を足して \cos (\theta_j + m) に戻すというステップを踏む必要があるためそう名付けられたのだと思われる。

またここからは雑感なのだが、Metric Learning において SphereFace、CosFace、ArcFace のような分類ベースの損失関数がなぜ Open-Set 画像分類(つまりテストデータに未知のクラスがある分類)でも割と上手くいくのかは、よくわからない。訓練データのクラスの分離しか学習できないのでは?という感じもしてしまう(まあ、人間には4次元以上の超次元データの分離の様子など、そもそもよく分からないのだけど....)。


References

Discussion