🖼️

シンプルなDiffusion Transformerを実装してみる

2024/12/17に公開

最近はTransformerベースのDiffusionモデル、いわゆるDiT(Diffusion Transformer)がUNetベースのDiffusionモデルよりも性能がいいと聞くもののコードレベルで理解ができなったのでMNISTの手書き数字の学習と生成を題材にシンプルなDiTを実装してみる。

(この記事はLLM・LLM活用 Advent Calendar 2024の16日目の記事です!
テーマは画像生成ですが、
GPTなどLLMで広く使われているTransformerを画像生成に応用したという意味で許してください🙏)

出力比較

UNetベースのシンプルなDiffusionモデル

ソースコード

TransformerベースのシンプルなDiffusionモデル(DiT)

ソースコード

10epochずつ学習させてみたが精度が出る学習率も違うし、どの段階で比較すればいいかがわからなかった。
とりあえずUNetベースと同じくらいの性能は出せるようになってよかった。

普段モデルの開発とかやらないにも関わらず、チューニングの手順やモデルの性能の評価方法などインプットせずに進めて大変効率が悪かったので、次回はその辺りをある程度体系的にインプットした上でモデルの改善に取り組みたい。

UNetベースの方はゼロから作るDeepLearningの5で作ったやつをそのまま流用していて、DiTはFacebookの実装を参考にさせていただいた。

もうちょっと詳細に解説を書きたかったけれど、間に合わなかったので以下簡略的に類似点と相違点をそれぞれo1に解説させてみる。

共通点

  • ディフュージョンプロセスの利用:どちらのモデルも、ノイズを段階的に加えていき、そのノイズを取り除く過程で元の画像を生成します。
  • ノイズの追加と除去のアルゴリズムDiffuserCondクラスを使用して、ノイズを画像に追加し、逆に除去する手順を実装しています。

違い

主な違いは、ノイズを予測するためのモデルの構造です。

  • UNetベースモデル:畳み込みニューラルネットワーク(CNN)であるUNetを使用しています。
  • Transformerベースモデル:自己注意機構を持つTransformerアーキテクチャを使用しています。

これから、それぞれのモデルのコードを比較しながら、その違いを詳しく見ていきましょう。


1. モデルの定義

UNetベースモデル

UNetは画像処理に特化したネットワークで、エンコーダとデコーダから構成されます。

conditional.py
class UNetCond(nn.Module):
    def __init__(self, in_ch=1, time_embed_dim=100, num_labels=None):
        super().__init__()
        # ダウンサンプリング(エンコーダ)部分
        self.down1 = ConvBlock(in_ch, 64, time_embed_dim)
        self.down2 = ConvBlock(64, 128, time_embed_dim)
        
        # ボトルネック部分
        self.bot1 = ConvBlock(128, 256, time_embed_dim)
        
        # アップサンプリング(デコーダ)部分
        self.up2 = ConvBlock(128 + 256, 128, time_embed_dim)
        self.up1 = ConvBlock(128 + 64, 64, time_embed_dim)
        
        # 出力部分
        self.out = nn.Conv2d(64, in_ch, 1)
        
        # プーリングとアップサンプリング
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear")
        
        # ラベル埋め込み(条件付きの場合)
        if num_labels is not None:
            self.label_embed = nn.Embedding(num_labels, time_embed_dim)

    def forward(self, x, timesteps, labels=None):
        # 実装詳細は省略
        pass

Transformerベースモデル

Transformerは自己注意機構を使って、データ内の関係性を学習します。

dit.py
class DiT(nn.Module):
    def __init__(
        self,
        input_size=28,
        patch_size=4,
        in_channels=1,
        hidden_size=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4.0,
        num_classes=1,
        learn_sigma=False,
        class_dropout_prob=0.1,
    ):
        super().__init__()
        # パッチ埋め込み層:画像を小さなパッチに分割し、埋め込みベクトルに変換
        self.x_embedder = PatchEmbed(
            input_size, patch_size, in_channels, hidden_size, bias=True
        )
        
        # 時間埋め込み層
        self.t_embedder = TimestepEmbedder(hidden_size)
        
        # ラベル埋め込み層(クラスドロップアウト付き)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        
        # ポジションエンコーディング
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.x_embedder.num_patches, hidden_size), requires_grad=False
        )
        
        # Transformerブロックのスタック
        self.blocks = nn.ModuleList(
            [
                DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
                for _ in range(depth)
            ]
        )
        
        # 最終出力層
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        
        # 重みの初期化
        self.initialize_weights()

    def forward(self, x, t, y):
        # 実装詳細は省略
        pass

2. 入力データの処理方法

UNetベースモデル

UNetモデルは、畳み込み層を使って画像全体を処理します。

  • 特徴
    • 畳み込み層は画像の局所的なパターン(例えばエッジや角など)を捉えるのに優れています。
    • 入力画像をそのままネットワークに渡します。

Transformerベースモデル

Transformerモデルは、パッチ埋め込みを使って画像を小さなブロックに分割して処理します。

  • 特徴
    • 画像をいくつかのパッチ(小さな矩形ブロック)に分割します。
    • 各パッチを1次元のベクトル(埋め込み)に変換します。
    • パッチ間の関係性を自己注意機構で学習します。

コードの違い

conditional.py
# UNetでは、画像全体をそのまま処理
def forward(self, x, timesteps, labels=None):
    x1 = self.down1(x, v)
    # 以下略
dit.py
# DiTでは、画像をパッチに分割してから処理
def forward(self, x, t, y):
    x = self.x_embedder(x) + self.pos_embed  # パッチ埋め込みとポジションエンコーディングの加算
    # 以下略

3. 時間とラベルの埋め込み方法

共通点

  • 時間埋め込み(タイムステップ情報):ノイズの量を示すタイムステップ情報を埋め込みベクトルに変換してモデルに入力します。
  • ラベル埋め込み:条件付きモデルの場合、クラスラベルを埋め込みベクトルに変換してモデルに入力します。

違い

  • UNetベースモデル
    • 時間埋め込みベクトルを各層に渡します。
    • ラベル埋め込みはオプションで使用します。
conditional.py
def forward(self, x, timesteps, labels=None):
    v = pos_encoding(timesteps, self.time_embed_dim, x.device)
    if labels is not None:
        v = v + self.label_embed(labels)
    # vを各層に渡す
  • Transformerベースモデル
    • 時間埋め込みとラベル埋め込みを足し合わせて、新たなコンテキストベクトルcを作成します。
    • cは各Transformerブロックで条件付けに使用されます。
dit.py
def forward(self, x, t, y):
    x = self.x_embedder(x) + self.pos_embed
    t = self.t_embedder(t)
    y = self.y_embedder(y, self.training)
    c = t + y
    for block in self.blocks:
        x = block(x, c)
    # 以下略

4. モデル内部の構造

UNetベースモデル

  • 特徴
    • エンコーダ部分で画像を小さく(情報を圧縮)し、デコーダ部分で元のサイズに戻します。
    • エンコーダとデコーダの対応する層で「スキップ接続」を使い、詳細な情報を保持します。
conditional.py
def forward(self, x, timesteps, labels=None):
    # ダウンサンプリング
    x1 = self.down1(x, v)
    x = self.maxpool(x1)
    x2 = self.down2(x, v)
    x = self.maxpool(x2)
    
    # ボトルネック
    x = self.bot1(x, v)
    
    # アップサンプリング
    x = self.upsample(x)
    x = torch.cat([x, x2], dim=1)  # スキップ接続
    x = self.up2(x, v)
    x = self.upsample(x)
    x = torch.cat([x, x1], dim=1)  # スキップ接続
    x = self.up1(x, v)
    x = self.out(x)
    return x

Transformerベースモデル

  • 特徴
    • パッチ埋め込みされたデータに対して、複数のTransformerブロックを通します。
    • 各ブロックで自己注意機構を使って、パッチ間の関係性を学習します。
dit.py
def forward(self, x, t, y):
    x = self.x_embedder(x) + self.pos_embed
    t = self.t_embedder(t)
    y = self.y_embedder(y, self.training)
    c = t + y
    for block in self.blocks:
        x = block(x, c)
    x = self.final_layer(x, c)
    x = self.unpatchify(x)
    return x

5. ノイズ予測の方法

UNetベースモデル

  • 出力:UNetは画像と同じサイズのノイズを予測します。
conditional.py
def forward(self, x, timesteps, labels=None):
    # ...(中略)...
    x = self.out(x)  # 最終的にノイズの予測を出力
    return x

Transformerベースモデル

  • 出力:Transformerはパッチ化されたデータを元の画像サイズに復元し、ノイズを予測します。
dit.py
def forward(self, x, t, y):
    # ...(中略)...
    x = self.final_layer(x, c)  # パッチごとの予測
    x = self.unpatchify(x)  # パッチを結合して元の画像サイズに復元
    return x

まとめ

  • 類似点

    • 両モデルともディフュージョンプロセスを用いて画像生成を行います。
    • 時間情報とラベル情報を埋め込みベクトルとしてモデルに組み込んでいます。
  • 違い

    • モデル構造
      • UNetベースは畳み込みニューラルネットワークで、画像の局所的な特徴を捉えるのが得意です。
      • Transformerベースは自己注意機構を使い、画像内の広範な関係性を学習できます。
    • データの処理方法
      • UNetは画像全体を処理します。
      • Transformerは画像をパッチに分割して処理します。

結論

UNetベースとTransformerベースのディフュージョンモデルは、共通の目的(画像生成)のために異なるアプローチを取っています。UNetは畳み込みを用いて局所的な特徴を捉え、Transformerは自己注意機構を用いてグローバルな関係性を学習します。

とのこと。

やってみてやはりただ作れるのとモデルを仮説を持って改善していけるのは全然違うと感じた。
モデル開発のつらみも色々わかったので今後はwandbなどうまく使っていきたい。

とりあえず、収穫としてはたまに聞く「unetとtransformerそんなに変わらない」という言説がなんとなく理解できるようになったのと、画像モデルも言語モデルも仕組みはそんなに変わらないのがコードレベルで理解できた。

あと最新の論文のサマリーを読んでもコードレベルでイメージできることが多くなったと思う。

やっとスタート地点に立てた気がするのでこれからなお精進したい。

以上。

Discussion