👨‍🎨

Image Style Transfer "CAST" 実装と解説(調整中)

2023/02/21に公開

調整中記事 ディスカッション求む


Domain Enhanced Arbitrary Image Style Transfer via Contrastive Learning
Author: Yuxin Zhang, Fan Tang, Weiming Dong, Haibin Huang, Chongyang Ma Tong-Yee Lee
https://github.com/zyxElsa/CAST_pytorch
https://arxiv.org/abs/2205.09542
SIGGRAPH 2022, 2022 May 20

をざっくり読んで本家実装を弄ってみます。
一行でまとめると、リアル写真と画風参考用の絵画画像、そしてそれらを組み合わせて生成した絵画風リアル画像をGANの方法で作り出し、それらを対象学習で表現させることで良い性能を得られたという話です。

概要

画風変換(style transfer)において、対象学習を用いて複数のスタイル間の類似度を分布からダイレクトに学習する手法は、既存の2次統計量による手法よりも歪みやアーティファクトを軽減できて、良い画像を作れる。
この手法をContrastive Arbitrary Style Transfer (CAST)と呼ぶ。
CASTは次の3つの貢献があり、画風変換タスクにおいて定性的・定量的に良好な結果が得られた。

  • 画風のエンコーディングのためのmulti-layer style projector(MSP)と、画風の表現として2次統計量を用いず画風変換するエンコーダデコーダを使ったこと
  • 画風の分布を効率的に学習するためのdomain enhancement(DE)で画風の正例・不例を対象学習することにより、既存手法ができないような大量の画風情報を扱ったこと
  • 定量的評価を行い、CASTが既存手法よりも人間を騙せるような画像を生成できることを示したこと

以下画像は特別に表記しない限り元論文より引用

CASTによる画風変換

画風変換の概要

画風(art styles)は形、色、構図など視覚的な表現で、作品を表現する特徴とも言える。
画風変換は、自然画像(写真など)の内容と既存の絵画のスタイルを組み合わせて、新しい絵画を作成する方法として発達してきた。
[1]によりCNNとグラム行列による画風表現の学習方法が開発され、高度なDNNによる画風変換が大きく発展したが、グラム行列は分散や平均を表す2次統計量が頼りなため、これには発展の余地が残されている。
実際、絵のスタイルが異なると色や局所的な質感、レイアウト、構図まで変わるため、2次統計量を用いた学習では、色分布やレイアウト、筆使いを模倣することが難しい。
画風表現について再検討すると、全体的なスタイル記述子として2次統計量を用いることは、ソース(画像の変換元)とターゲット(変換したい画風)間の統計量に適合するように学習することになる。これは人工的に設計された特徴と損失関数であらかじめ定式化されていて、自然ではない。
CASTでは、ターゲットから直接画風の関係や分布を探る方法として対象学習による最適化を考察している。
ここで元論文では「芸術家でない人が1枚の絵を見せられても画風を見極めるのは困難だが、異なる画風の絵を与えられれば違いを識別することは簡単になる」という洞察から、対照学習に目をつけている。
CASTはエンコーダーデコーダーモデルのバックボーンと、画風の各特徴をエンコードして空間を構成するMSP、画風間の関係を対象学習で学習するDEによって構成されている。

モデル全体の学習の様子。GANのようにGeneratorを学習することで転写元と画風画像から画風変換後の画像を生成する。

モデルの全貌

CASTは画風の特徴をエンコードするMSP、MSPとGeneratorの学習を導く対象学習モジュール、画風の分布の学習を助ける強化スキームの3つから成り、入力された画像と生成された画像との差を測定する画風表現の学習に使用される。

特徴抽出器E(VGG-19)、PoolingベースのエンコーダPにより潜在表現を得ている。

MSP

NNへ局所的な筆使い(特徴)と全体的な外観を伝えるために、画風の判別と変換画像の生成を導くための画風表現を構成したい。CASTでは、複数レイヤー(特徴テンソルの行間)を跨ぐ特徴を用いるのではなく、MSPによって異なるレイヤーの特徴を別々の潜在空間に射影することで、グローバルとローカルの特徴のエンコードを行う。
事前学習したVGG-19を用いて特徴抽出したM個の縦横のテンソルを、さらにNNを用いてM×K次元の潜在表現(Style code)に変換する。MSPは各画像を処理するとき重みは共有している。
この潜在表現は2次統計量における平均と分散のように扱われ、画風変換モデルのガイダンスとして用いられる。

MSPは次のように動作する。M = 4のとき中間層出力の特徴量は4つ。

  1. 画像I_1と、それをオーグメントした画像I_2、そしてその他の関係ない画像I_oがある。
  2. VGGエンコーダEを用いて、特徴量(t_{1,1}, t_{1,2}, t_{1,3}, t_{1,4}) = E(I_1)(t_{2,1}, t_{2,2}, t_{2,3}, t_{2,4}) = E(I_2)(t_{o,1}, t_{o,2}, t_{o,3}, t_{o,4}) = E(I_o)とする。tのサイズは[(中間層のChannel数),h,w] * M個。これをデコーダの入力次元に変形する。
  3. NNデコーダPを用いて、潜在表現z_1 = P(t_1)z_2 = E(t_2)z_o = P(t_o)とする。zのサイズは[M,K]。

これにより得られた潜在表現を用いて、次の対照学習を行う。

対象学習

画風変換のためのガイドとなる潜在表現を学習したいが、画風に教師データが設計できないという問題を解決したい。ここで対照学習を使う。
対照学習では、画像I_1とその潜在表現z_1、それをオーグメントした画像I_2z_2、そしてその他の関係ない画像I_oz_oについて、z_1とその正例z_2、不例z_oを比較し、正例との相互情報量を最大化することで、画像の特徴を学習していく。
MSPとGeneratorの損失関数\texttt{L-contra-MPS}\texttt{L-contra-G}は次のように定義する。

MPSの損失関数\texttt{L-contra-MPS}
Generatorの損失関数\texttt{L-contra-G}

Unpaired image-to-image Translation(CUT)[2]のように画像を切り出したパッチ間の相互情報を最大化する学習ではなく、画像間の対比損失を計算している。

Domain Enhancement

モデルが潜在表現の分布を学習するため、敵対的損失を持つDomain Enhancement(DE)を導入する。
GANを用いた敵対的損失はデータセットの分布に強く依存するが、生成画像の全体的なスタイルを強化することができる。
CASTでは、学習データ空間の中からリアル領域(写真などの画風)とアート領域(絵画などの画風)を分けて、それらの空間に属するかどうかをDiscriminatorに識別させる学習を行う。

  1. リアル画像I_rを絵画の画像I_aの画風で変換したい。リアル領域のものか、アート領域のものかを判別するDiscriminators D_rD_aを用意する。(論文とは異なる文字を使用)
  2. Generator GによりI_rI_aをそれぞれ画風変換し、ドメインを交換したアート領域画像I_{r→a} = G(I_r,I_a)とリアル領域画像I_{a→r} = G(I_a,I_r)を生成する。
  3. I_{r→a}I_{a→r}はそれぞれD_aD_rのフェイクサンプルとして使う。

このDiscriminatorの損失\texttt{L-adv}は次の式で定義する。
敵対的生成で用いる損失関数\texttt{L-adv}

画風変換時に変換元の画像のコンテンツ情報を維持するため、サイクル一貫損失\texttt{L-cyc}を次の式で定義する。
サイクル一貫損失\texttt{L-cyc}

学習

GeneratorとDiscriminatorの学習では\texttt{L-adv}%, \texttt{L-cyc}, \texttt{L-contraG}の線形和を最終的な損失関数Contrastive style lossとして計算する。それぞれの係数はそれぞれ1, 2, 0.2が使われた。
学習データはアート領域画像をWikiArtから、リアル領域画像をPlaces365から、それぞれ256×256pxで20000枚サンプルしている。
VGGから得られるテンソルはM=4、潜在表現の次元はM層のそれぞれでK=512, 1024, 2048, 2048。
最適化はAdam(betas=(0.5, 0.999), lr=0.0001)、lr schedulerは線形減衰。batch sizeは4で行われた。

実験結果

CASTによる画風変換ではスタイルパターンと転写元の画像の構造の特徴を両方捉えられていて、既存手法のようなアーティファクトや歪みが出にくい。
これはグローバルなスタイル記述子を2次統計量で表すのではなく、MSPとDEによる学習で効果的に表現できるようになったからであると考えられている。

評価にはcontent loss, LPIPS, deception rateを用いる。
content loss[3], LPIPS[4]は事前学習済みモデル(VGG-19)を用いて、転写元画像と変換後の画像のAverage Perceptual Distancesを計算するもので、deception rateはWikiArt上の10種類の画風を分類するタスクで学習したVGG-19が変換後の画像を正しい画風の分類として予測する割合である.
このほかにも人間に対するアンケートでCASTと既存手法の画風変換どちらが良いか選択してもらう評価を行った結果、Sketch, Chinese painting, Impressionismの画風でCASTは高い評価を得ていたり(表のUser Study Iの欄)、Stylized Authenticity Detectionと呼ばれる、変換後の画像を混ぜた画像群の中からそれを見抜いてもらう方法を使って評価した結果、他手法より被験者をよく欺いた(表のUser Study IIの欄)。

これら指標を計算した結果は次の表であり、CASTが既存手法を総合的に上回っていることがわかる。
他手法との評価指標の比較

アブレーションスタディ

Contrastive style lossを取り除いて既存のグラム行列ベースの損失の手法に置き換えたり、Domain Enhancementにおいて識別機DI_rI_aで分けずに1つの識別機(mix-DE)で行ったり、\texttt{L-cyc}を消したりして結果を試している。
各手法を既存の方法で代替したときの生成結果

このようなCASTの構成要素を排除した画風変換ではクオリティが劣化していて、それぞれの提案手法が効いていることが示された。

CASTの実装

公式実装を読み、主要な部分の実装を見ていく。以下のコード部分は特記しない限り公式実装からの引用である。

MPSの実装

MPSのencoderはVGG-19をM層に分けた中間出力を得る。
encoderのVGG-19はM=4とするとき、64, 128, 256, 512のチャンネル数を持つ中間特徴量を出力する。

vgg-19の構成
# vgg-19(Shapeは入力画像244×244の場合)
==========================================================================================
Layer (type (var_name))                  Input Shape               Output Shape
==========================================================================================
Sequential (Sequential)                  [1, 3, 224, 224]          [1, 512, 14, 14]
├─Conv2d (0)                             [1, 3, 224, 224]          [1, 3, 224, 224]
├─ReflectionPad2d (1)                    [1, 3, 224, 224]          [1, 3, 226, 226]
├─Conv2d (2)                             [1, 3, 226, 226]          [1, 64, 224, 224]
├─ReLU (3)                               [1, 64, 224, 224]         [1, 64, 224, 224] -> enc_1
├─ReflectionPad2d (4)                    [1, 64, 224, 224]         [1, 64, 226, 226]
├─Conv2d (5)                             [1, 64, 226, 226]         [1, 64, 224, 224]
├─ReLU (6)                               [1, 64, 224, 224]         [1, 64, 224, 224]
├─MaxPool2d (7)                          [1, 64, 224, 224]         [1, 64, 112, 112]
├─ReflectionPad2d (8)                    [1, 64, 112, 112]         [1, 64, 114, 114] 
├─Conv2d (9)                             [1, 64, 114, 114]         [1, 128, 112, 112]
├─ReLU (10)                              [1, 128, 112, 112]        [1, 128, 112, 112] -> enc_2
├─ReflectionPad2d (11)                   [1, 128, 112, 112]        [1, 128, 114, 114]
├─Conv2d (12)                            [1, 128, 114, 114]        [1, 128, 112, 112]
├─ReLU (13)                              [1, 128, 112, 112]        [1, 128, 112, 112]
├─MaxPool2d (14)                         [1, 128, 112, 112]        [1, 128, 56, 56]
├─ReflectionPad2d (15)                   [1, 128, 56, 56]          [1, 128, 58, 58]
├─Conv2d (16)                            [1, 128, 58, 58]          [1, 256, 56, 56]
├─ReLU (17)                              [1, 256, 56, 56]          [1, 256, 56, 56] -> enc_3
├─ReflectionPad2d (18)                   [1, 256, 56, 56]          [1, 256, 58, 58]
├─Conv2d (19)                            [1, 256, 58, 58]          [1, 256, 56, 56]
├─ReLU (20)                              [1, 256, 56, 56]          [1, 256, 56, 56]
├─ReflectionPad2d (21)                   [1, 256, 56, 56]          [1, 256, 58, 58]
├─Conv2d (22)                            [1, 256, 58, 58]          [1, 256, 56, 56]
├─ReLU (23)                              [1, 256, 56, 56]          [1, 256, 56, 56]
├─ReflectionPad2d (24)                   [1, 256, 56, 56]          [1, 256, 58, 58]
├─Conv2d (25)                            [1, 256, 58, 58]          [1, 256, 56, 56]
├─ReLU (26)                              [1, 256, 56, 56]          [1, 256, 56, 56]
├─MaxPool2d (27)                         [1, 256, 56, 56]          [1, 256, 28, 28]
├─ReflectionPad2d (28)                   [1, 256, 28, 28]          [1, 256, 30, 30]
├─Conv2d (29)                            [1, 256, 30, 30]          [1, 512, 28, 28]
├─ReLU (30)                              [1, 512, 28, 28]          [1, 512, 28, 28] -> enc_4
├─ReflectionPad2d (31)                   [1, 512, 28, 28]          [1, 512, 30, 30]
〜
├─ReLU (52)                              [1, 512, 14, 14]          [1, 512, 14, 14] 

エンコーダとして用いるVGGは勾配から切り離し学習させない。
ADAIN_Encoderのenc_layers部分でVGGが切り離されており、論文中の複数レイヤーとはこれのことを指す。

net.py
# MPSのエンコーダ

class ADAIN_Encoder(nn.Module):
    def __init__(self, encoder, gpu_ids=[]):
        super(ADAIN_Encoder, self).__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])     # input   -> relu1_1 64
        self.enc_2 = nn.Sequential(*enc_layers[4:11])   # relu1_1 -> relu2_1 128
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1 256
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1 512
        
        self.mse_loss = nn.MSELoss()

        # fix the encoder
        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

    # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
    def encode_with_intermediate(self, input):
        results = [input]
        for i in range(4):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    def calc_mean_std(self, feat, eps=1e-5):
        # eps is a small value added to the variance to avoid divide-by-zero.
        size = feat.size()
        assert (len(size) == 4)
        N, C = size[:2]
        feat_var = feat.view(N, C, -1).var(dim=2) + eps
        feat_std = feat_var.sqrt().view(N, C, 1, 1)
        feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
        return feat_mean, feat_std

    def adain(self, content_feat, style_feat):
        assert (content_feat.size()[:2] == style_feat.size()[:2])
        size = content_feat.size()
        style_mean, style_std = self.calc_mean_std(style_feat)
        content_mean, content_std = self.calc_mean_std(content_feat)

        normalized_feat = (content_feat - content_mean.expand(
            size)) / content_std.expand(size)
        return normalized_feat * style_std.expand(size) + style_mean.expand(size)

    def forward(self, content, style, encoded_only = False):
        style_feats = self.encode_with_intermediate(style)
        content_feats = self.encode_with_intermediate(content)
        if encoded_only:
            return content_feats[-1], style_feats[-1]
        else:
            adain_feat = self.adain(content_feats[-1], style_feats[-1])
            return  adain_feat

// TODO 続きの実装読み次第書きます

感想

画像間の特徴をうまくNNに埋め込みたいというタスクを対照学習によって解決するというのは自然で、実際二次統計量ベースのものより良く変換できているように見える。
図に示されている画風変換はとてもクオリティが高く見えるのだが、WikiArtに無い画風(例えばアニメ調の現代的なイラスト)や自然画像以外の転写元に対する外挿的な汎化性能については気になるところである。
RELATED WORKにもあるが、雰囲気としてはGANベースの手法に対照学習を取り入れたシンプルなアイディアに見えるので、あまり周辺について勉強してなくても実装できそうに感じる。

参照

[1] Leon A Gatys, Alexander S Ecker, and Matthias Bethge. "Image style transfer using convolutional neural networks." CVPR 2016. https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf

[2] Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros. "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks", ICCV 2017. https://arxiv.org/abs/1703.10593

[3] Yijun Li, Chen Fang, Jimei Yang, Zhaowen Wang, Xin Lu, Ming-Hsuan Yang. 2017. "Universal style transfer via feature transforms.", In Advances Neural Information
Processing Systems (NeurIPS). 386–396.

[4] Haibo Chen, Lei Zhao, Zhizhong Wang, Zhang Hui Ming, Zhiwen Zuo, Ailin Li, Wei Xing, Dongming Lu. 2021a. "Artistic Style Transfer with Internal-external Learning and Contrastive Learning.", In Advances in Neural Information Processing Systems (NeurIPS).

[5] Artsiom Sanakoyeu, Dmytro Kotovenko, Sabine Lang, Björn Ommer. 2018b. "A Style-Aware Content Loss for Real-Time HD Style Transfer.", In European Conference Computer Vision (ECCV). Springer International Publishing, Cham, 715–731.

Discussion