🤓

[Diffusion Model] Hypernetworksのレイヤー構造を変えた際の変化を比較する

2022/10/19に公開約9,400字

はじめに

はじめまして、なんかと申します。

少し前に、NovelAIから新しい学習済みLatent Diffsion Modelのチューニング手法としてHypernetworksが提案されました。

しかし、提案とは言っても「ウチではこういうのやってるよ」程度のモノであり、まだまだどのようなレイヤー構造が良いのか、どんなハイパーパラメータが適しているのか、といった探索はほとんど為されていません。

そこで今回は、Hypernetworksについて少し説明を行ってから、使用するモデルやハイパーパラメータを固定し、Hypernetworksのレイヤー構造を変えた場合に結果がどう変わるかについて簡単な比較・検証を行っていきます。

[追記: 10/21 12:52]
この記事で使用しているHypernetworks全てに活性化関数が入っていないため、この記事の比較・検証にはあまり意味がありません。
本日中に活性化関数を追加して記事を更新します。

(Kerasばっかり触ってるのでPyTorchのAPI何も分かってません)

[追記: 10/21 18:54]
活性化関数の有無を比較対象に追加し、記事を大幅に更新しました。

Hypernetworksについて

この手法についてご存知でない方のために、まずは提案が為されたブログより、説明を一部引用しつつ簡単に解説します。

It should be noted that this concept is entirely disparate from the HyperNetworks introduced by Ha et al in 2016, which work by modifying or generating the weights of the model, while our Hypernetworks apply a single small neural network (either a linear layer or multi-layer perceptron) at multiple points within the larger network, modifying the hidden states.

最初に注意すべきこととして、非常に面倒なことに2016年のHa et alによる提案手法と名前が重複していますが、全く異なるものです。

NovelAIの提案するHypernetworksは、既存の大きなDNNの間に小さなNNを複数追加し、それを追加のデータセットで学習させることで、学習済みモデルのパラメータを変更することなくhidden statesを弄る、という手法です。

ここで問題となってくるのが、では既存の学習済みNN(ここではStable Diffusionとして話を進めます)のどの部分にHypernetworksを適用すると良いのかという疑問で、極端な話、Hypernetworksを全てのレイヤーの後に追加すると最も正確に追加データセットの特徴を学習してくれるかもしれませんが、それでは学習推論双方において計算量が膨大なものになってしまうため、チューニング手法として適切ではありません。

そのため、できる限り少ないパラメータ量で、大きな改善が見込めるポイントにのみHypernetworksを適用することが重要となってきます。

そこで、NovelAIが試行錯誤の末辿り着いたのが、U-Netのクロスアテンション機構のうちkeyとvalueにのみHypernetworksを適用することでした。

After many iterations testing many different architectures, Aero was able to come up with one that is both performant and achieves high accuracy with varied dataset sizes. The hypernets are applied to the k and v vectors of CrossAttention layers in StableDiffusion, while not touching any other parts of the U-net.

クロスアテンション機構の説明については割愛します。
個人的な理解ですが、学習済みのネットワークが提示したqueryに対して与えられるkeyとvalueがHypernetworksによって変更されることで、画像であれば任意のピクセルと紐つくピクセルが追加データセットそれぞれのスタイルによって変わるため、クロスアテンション機構に対して直に絵柄を教え込んでいるようなものだと考えています。

検証について

さて、本題に入ります。
NovelAIが提案する、U-Netのクロスアテンション機構内部のkey, valueをHypernetworksによって変換する手法(長すぎる、以下単純にHypernetworksと呼びます)をとる際、以下の課題があることがNovelAI自身により指摘されています。

We found that the shallow attention layers overfit quickly with this approach, so we penalize those layers during training. This mostly mitigated the overfitting issue and results in better generalization at the end of training.

この文章は(おそらく意図的に)非常に曖昧なものになっているため、解釈のやりようが複数ある気がしますが、重要なことは適当に浅いNNを作ると過学習しやすいことです。

さっそく検証してみましょう。
みんな大好きAUTOMATIC1111さんによるstable-diffusion-webuiには既にHypernetworksが組み込まれており、ユーザーが用意したデータによって学習・推論が行えるようになっているので、それを利用します。

比較内容

Hypernetworksのレイヤーを増やす/増やさないに加え、Layer Normalization(以下LN)を入れる/入れないを切り替えてそれぞれ学習させます。

[追記 10/21 16:47]上記に加え、活性化関数(ReLU)の有無を比較に加えます。

そして、以下の結果を定性評価し、どのようなレイヤー構造だと過学習に陥りやすい/にくいのかについて確認します。

結果1. txt2imgで、プロンプト・ハイパーパラメータ・シードを揃えて1枚画像を生成
結果2. txt2imgで、プロンプト・ハイパーパラメータを揃え、ランダムシードで6枚画像を生成

また、Hypernetworksのレイヤー構造は(1, 2, 1)のように表記します。
それぞれの数字はレイヤー間の特徴量の次元数を示し、例えば、(1, 2, 1)はHypernetworksの1つが以下のような構造であることを意味します。

Sequential(
  (0): Linear(in_features=1280, out_features=2560, bias=True)
  (1): ReLU()
  (2): Linear(in_features=2560, out_features=1280, bias=True)
  (3): ReLU()
)

比較環境

以下の条件を揃えた上で、Hypernetworksのレイヤー構造のみを変更して学習を行います。

学習済みモデル

Waifu Diffusion v1-3 float16

学習用データ

  • 画像

    上の学習済みモデルから生成した画像8枚をFlipして16枚にしたもの、サイズは576*576
    どの画像も目が大きく、今時(?)な絵柄で、キラキラしまくっているのが特徴です。

  • テキスト
    各画像をBLIPでInterrogateしたもの

パイパーパラメーター

  • Steps: 5000
  • Learning Rate: 5e-6 ~500Steps, 3e-06 ~1000Steps, 2e-06 ~5000Steps
  • Batch Size: 1
  • Optimizer: AdamW
  • Loss: LDM Loss

比較結果

0. Hypernetworksなし - ベースライン

  • 結果1

  • 結果2

1. (1, 2, 1), LNなし, Linear

Sequential(
  (0): Linear(in_features=1280, out_features=2560, bias=True)
  (1): Linear(in_features=2560, out_features=1280, bias=True)
)
  • 学習過程

  • 結果1

  • 結果2

学習データを参考にして、目をキラキラさせたり、アニメ的な画風には持っていくことは出来ていますが、絵柄(特に顔の感じ)はそこまで近くなっていません。
やはり、単純な線形のネットワークだとうまく特徴を学習できないようです。

2. (1, 2, 1), LNなし, ReLU

Sequential(
  (0): Linear(in_features=1280, out_features=2560, bias=True)
  (1): ReLU()
  (2): Linear(in_features=2560, out_features=1280, bias=True)
  (3): ReLU()
)
  • 学習過程

  • 結果1

  • 結果2

上の結果よりは学習データの雰囲気に近づきましたが、まだうまく学習できていないのか、出力の傾向が一昔前の絵柄(?)になりがちに見えます。

3. (1, 2, 1), LNあり, ReLU

Sequential(
  (0): Linear(in_features=1280, out_features=2560, bias=True)
  (1): ReLU()
  (2): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
  (3): Linear(in_features=2560, out_features=1280, bias=True)
  (4): ReLU()
  (5): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
)
  • 学習過程

  • 結果1

  • 結果2

Layer Normalizationを入れることで、学習データと絵柄が近い出力がグッと増えました。浅いレイヤー構造なのは上の2つと同様ですが、こちらはうまく学習できているのではないでしょうか?
一般に、Layer Normalizationには収束を早める効果があるため、本来5000Stepsでは収束しなかったはずの学習がLNによって上手くいったのかもしれません。

4. (1, 2, 2, 1), LNなし, Linear

Sequential(
  (0): Linear(in_features=1280, out_features=2560, bias=True)
  (1): Linear(in_features=2560, out_features=2560, bias=True)
  (2): Linear(in_features=2560, out_features=1280, bias=True)
)
  • 学習過程(3000stepsで収束)

  • 結果1

  • 結果2

活性化関数が無いことにはいくらレイヤーを増やしても線形変換なのは変わらないので、(1, 2, 1)の活性化関数なしの際と同様の結果が出てくるのでは?と思いましたが、見た感じ活性化関数ありには及ばずとも、それなりに学習が上手くいっているように見えます。
NNなんも分からん…。

5. (1, 2, 2, 1), LNなし, ReLU

Sequential(
  (0): Linear(in_features=1280, out_features=2560, bias=True)
  (1): ReLU()
  (2): Linear(in_features=2560, out_features=2560, bias=True)
  (3): ReLU()
  (4): Linear(in_features=2560, out_features=1280, bias=True)
  (5): ReLU()
)
  • 学習過程

  • 結果1

  • 結果2

出力をチェリーピッキングしていないため、結果2の最初の2枚が微妙なのはランダムシードの問題だとして、なかなか上手に学習できているのではないでしょうか?
しかし、学習に使用しているデータセットが非常に小さいからだとは思いますが、レイヤー構造が(1, 2, 1)でLNありかつReLUを入れた場合とそこまで有意な差があるとは思えません。
LNを追加するより全結合層を追加した方が学習・推論にかかる時間が伸びるため、データセットのサイズによってはLNを追加するだけの方がコストパフォーマンス的に良さそうです。

6. (1, 2, 2, 1), LNあり, ReLU

Sequential(
  (0): Linear(in_features=1280, out_features=2560, bias=True)
  (1): ReLU()
  (2): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
  (3): Linear(in_features=2560, out_features=2560, bias=True)
  (4): ReLU()
  (5): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
  (6): Linear(in_features=2560, out_features=1280, bias=True)
  (7): ReLU()
  (8): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
)
  • 学習過程

  • 結果1

  • 結果2

これも上手く学習できていると思いますが、やはり上で言ったことと同じく、(1, 2, 1)、LNあり、ReLUの場合との違いが誤差レベルに思えます。

おわりに

Hypernetworksのレイヤー構造を変更して結果を比較したことで、以下のような示唆 or Tipsが得られました。

  • (1, 2, 1)のような浅いレイヤー構造でも、適切な活性化関数の導入によって学習を上手くやることができる
    • しかし、データセットが大きくなるとそうとは限らない
  • 活性化関数, Layer Normalizationは入れておいて基本損せず、特に活性化関数が無い場合は学習が上手くいかないかもしれない
  • [2022/10/30 追記]: Layer Normalizationが学習を不安定にさせることがあるようです。また、汎化性能より学習した画像のスタイルの忠実な再現を目指す場合、むしろ過学習させたほうがいいかもしれないので、その場合は活性化関数なし、Layer Normalizationなしがむしろ良く働くかもしれません。

こうして見ると、一般的なDNNのチューニングで言われている経験的なTipsとあまり変わらないような気がしなくもないですね…。

もしやる気が続けば、データセットを拡張したり、学習率・ステップ数のようなハイパーパラメータを変更しての検証も行いたいと思います。

Discussion

ログインするとコメントできます