Domain Adversarial Neural Networksについて(前編)
はじめに
先日参加したKaggleのOpenVaccine: COVID-19 mRNA Vaccine Degradation Predictionコンペティションで自分が参加していたチームではDomain Adversarial Neural Networks (DANN)と呼ばれる手法を用いていました。
結果としては、CV, Public LB, Private LBのいずれにも効いていないことが判明したのですが、Kaggleで度々話題になるAdversarial Validationとも類似した面白い技術なので、改めて紹介するとともに本当に使える手法なのかを検証していきたいと思います。
本記事は二部構成(三部構成、2020/10/25更新)になっており、前編(この記事)ではDANNの紹介と、論文中でも紹介されているMNIST/MNISTMを用いて検証を行います。後編つづく中・後編では実際のKaggleのコンペティションのデータを用いてDANNがKaggleでも使えそうな技術なのか、実験をして検証します。
Domain Shiftについて少し
DANNの説明に入る前に、Domain Shiftについて少し触れておきましょう。機械学習の研究・検証の多くの場面では暗黙の前提として訓練データとテストデータ(運用データ)は同じ分布から採取されているという仮定をおいています。
一方で、実際に機械学習のアルゴリズムが運用される多くの場面において、テストデータは学習データと異なる分布から生成されているという現実があります。この記事を読んでくださっている読者の方の中には、実際に経験したことがあるという方もいるかと思います。
この分布のずれをDomain Shiftと呼ぶのですが、このような状況では検証用のデータセットでは良好な成績を誇っていたモデルがいざ運用に出してみるとあまり性能を発揮できていない、ということがよく起こります。Kaggleでもこのような設定になっているコンペティションは数多くあり多くのKagglerを悩ませてきたというのは次のmemeからも見て取れます。
Domain Adversarial Neural Networks
さて、DANNはDomain Shiftが起こっている状況で、Target[1]データ自体は手元にありドメインラベルが得られる状態[2]で適用できるニューラルネットワークを用いた表現学習の手法です。ニューラルネットワークを特徴抽出器と分類器(回帰器)が連結したものだと見たうえで
- Source側のラベルで教師ありタスクを解きつつ
- 教師ありタスクを解くのに使うのと同じ特徴を受け取ってドメインラベルを分類するドメイン識別問題を解き
- ドメイン分類器と特徴抽出器の間に勾配を逆転させる層を一枚挟むことで特徴抽出器がドメインを識別できない方向に学習が進むように勾配を逆伝播させる
というアイデアになります。
Domain Shiftのなかでも、入出力規則はSourceとTargetで変わらないが入力の分布が変化してしまうCovariate Shift(共変量シフト) と呼ばれる状況で有効とされ、近年のDomain Adaptation
KagglerからみたDANN
一見すると、Kaggle、特にcsvを提出するタイプの古いコンペティションでは非常に有効そうな手法ですが、驚いたことにKaggle界隈で見かけることはほとんどありません。
一方で、ドメイン識別器を作るというアイデアはたびたび話題に上がるAdversarial Validationとも類似しており、Adversarial Validationではドメイン識別の結果をデータ理解に役立てるなどする一方、DANNではその情報をそのまま表現にフィードバックしてしまってドメイン不変表現を得る、という点でDANNの方が幾分直接的で思い切りのいい手法のように思えます。
さて、一見良さそうな技術だがKagglerの間で使われていない技術がある、というときに二つの可能性があります。
- 単純にKagglerが知らない・概念がimportされていない
- あんまりKaggleで使えない技術である・公称の性能に疑義がある
Kaggleではコンペの中で精度を追求するという性質上、使える技術であれば数ヶ月前にarXivに公開された手法が実戦投入されるということも比較的よく起こります。研究と実運用の中間にある試験場のような場所となっている側面もあるため、2の可能性も考えられるということです。
MNIST vs MNISTMで検証
それでは、論文で紹介されていたSourceがMNIST、TargetがMNISTMの例で試してみましょう。MNISTは皆さんご存知の手書きの数字のデータセットですが、MNISTMはMNISTの背景と文字の中を適当な画像などで入れ替えてしまったデータセットになり、見た目はMNISTとかなり異なります。
なかには人間でも判別がしづらいものもあり、MNISTで学習されたモデルをそのままMNISTMに適用してもうまくいかなそうなことは直感的に理解できるかと思います。実際、論文中ではその設定ではTargetのAccuracyが0.5225
になってしまうと報告されています。
実装
実装はhttps://github.com/koukyo1994/domain-adversarial-nn で公開しています。ここでは重要な部分を抜粋して紹介します。
まず肝心のDANNの実装ですが、
class DomainAdversarialCNN(nn.Module):
def __init__(self, img_size=32):
super().__init__()
self.img_size = img_size
self.feature_extractor = nn.Sequential(
(略)
)
in_features = self._get_in_features()
self.classifier = nn.Sequential(
nn.Linear(in_features, 100),
(中略)
nn.Linear(100, 10)
)
self.domain_classifier = nn.Sequential(
nn.Linear(in_features, 100),
nn.BatchNorm1d(100),
nn.ReLU(),
nn.Linear(100, 1)
)
def _get_in_features(self):
(略)
def forward(self, x, alpha):
batch_size = x.size(0)
x = self.feature_extractor(x).view(batch_size, -1)
y = GradientReversalLayer.apply(x, alpha)
x = self.classifier(x)
y = self.domain_classifier(y)
return {
"logits": x.view(batch_size, -1),
"domain_logits": y.view(batch_size, -1)
}
このようになっています。classifier
もdomain_classifier
も特に捻りはありませんが、唯一特徴的な点としてはforward
メソッドの中で用いているGradientReversalLayer
でしょう。これは、domain_classifier
側の勾配をfeature_extractor
に伝播する際に符号を反転する機能を持っており次のような実装になっています。
from torch.autograd import Function
class GradientReversalLayer(Function):
@staticmethod
def forward(context, x, constant):
context.constant = constant
return x.view_as(x) * constant
@staticmethod
def backward(context, grad):
return grad.neg() * context.constant, None
たったこれだけです。
一方、学習の際にはTarget側のデータも用います。Domain Classificationと主タスクの学習を交互に行うか、同時に行うかという点で選択の余地がありますが、私の実装では同じDatasetに入れてしまって(データ数のバランスをとって)、ドメインラベルが1
(mnistm)のサンプルについては主タスクのロスを計算しないようにして学習を進めるようにしました。
また、元論文との違いとして(2020/10/25更新、これは誤りで元の論文で用いられているものでした。)domain_classifier
と主タスクの分類器のバランスをとる係数(実装の中ではalpha
)をexponential warm-upするようにしました。これは参考にしたレポジトリのなかで使われていたテクニックで、Target側での主タスクの性能がよくなる効果があるようです。
total_steps = self.num_epochs * self.loader_len
p = float(self.loader_batch_step +
(self.epoch - 1) * self.loader_len) / total_steps
alpha = 2.0 / (1.0 + np.exp(-10 * p)) - 1
結果
まず、ドメイン分類のAUCですが以下のようになりました。
ある程度のブレはあるものの、0.6
を中心に0.5 - 0.7
の範囲に収まっており、"SourceとTargetを分類できなくする"という目論見はうまくいっているようです。
続いて、Source側のAccuracyの推移ですが以下のようになりました。
50epochかけて0.991あたりに収束しているように見えます。Sourceのみの学習だと0.995くらい、とても頑張ると0.998ぐらいなので若干下がってしまっています。
さて、本命のTarget側のAccuracyです。
なんと0.9超えを達成しています!改めて、Target側のラベルは存在はしますが、今回の学習には一切使用していないことを特筆しておきます。公称ではMNIST/MNISTMはDANNを使うことで0.5225
から0.7666
に上がるとのことですが、alpha
をうまくいじってあげることで0.9
くらいまでは上げられることが分かりました。
DANNなしの時と比較
さてこうなってくると気になるのが、DANNなしの時はどうなのか、という話です。そこでドメイン分類器のブランチを無くしてドメイン分類器側からのフィードバックをしないようにして実験を行ってみました。つまり、なんの工夫もない状態です。
まず、Source側のAccuracyです。
0.995あたりを彷徨っていることが分かります。やはり、DANNを使った時より若干高い精度を達成しているようです。
一方、本命のTarget側のAccuracyは
0.5台を彷徨っているようで、これは論文に載っていた数字とも概ね一致します。やはり、なんの工夫もない状態では難しいようです。
DANNは「ドメインが見分けられなくなるような特徴表現を学習する」手法ですので、本当にそのような効果があるのかみてみましょう。DANNありの時とDANNなしの時でfeature_extractor
にデータを入力して特徴表現を得て、それをUMAPで二次元に圧縮して表示してみました。
まずDANNなしの時は、
MNISTMが中央にギュッと集まる一方でMNISTはその周辺にいくつかクラスタを作っていることが分かります。
クラスの分布と重ねて見ると、MNISTはdigitごとのクラスタに分けられている一方でMNISTMはクラスが重なり合ってしまっているようです。
一方、DANNありの場合、
MNISTとMNISTMの重なりはよくなっているようです。一方でいまだにMNISTMのみの領域も多く、ドメイン分類器のAUCが0.5だったことは若干怪しさを感じます。特徴表現へのフィードバックができているのは確かなようですが、そこから伸びたドメイン分類器のブランチがあまりうまく学習できていないのかもしれません。
一方でクラスの分布も併せて見ると、Target側でAccuracyが0.9だったのもうなずけます。特徴の分布がSourceとTargetで同じになることはないようですが、少なくとも近づけることでTarget側でもクラス間の分離が進んでいることが分かります。
まとめ
さて、長くなってしまいましたが、前編はここまでです。この記事では
- Domain Adversarial Neural Networksの紹介
- MNIST vs MNISTMの実験で検証
の2点を行いました
結果としては論文で紹介されている以上の性能を見ることができましたが、SourceとTargetの特徴を近づける、という点ではまだ不十分ではないかという点は要検証でしょう。次回はこの点についての考察を深めるとともに、実際のKaggleの過去のコンペティションのデータを用いてDANNを検証してみたいと思います。
-
Domain Adaptationの研究ではSourceとTargetの間にDomain Shiftがあるとします。Source側にはなんらかの教師あり学習タスクのラベルがあり、Target側にはそのタスクのラベルはない、という状況でTarget側の予測性能などをできる限り高めよう、という研究領域をUnsupervised Domain Adaptationと呼び、盛んに研究がなされています。 ↩︎
-
実運用上、このような状況 - テストデータのラベルはないがドメインラベル付きの状態でデータ自体はある状況 - は多くないという批判もあります。一方で昔はKaggleのコンペ(CSVを提出する方式の時代)は基本この設定が成り立っていました。最近はコードコンペが増えてテストデータが隠されていることが増えているのでそういう時はDANNは使えませんね。 ↩︎
-
あんま適当なことをいうと怖い人に怒られそうなのでお茶を濁しておきます。近年の研究動向などは、https://www.slideshare.net/DeepLearningJP2016/dl-149549473 などに詳しく説明されています。 ↩︎
Discussion