GENIAC第2期成果/データグリッド:汎用画像生成基盤モデルの軽量化のためのTransformersベースLDM—LocalDiT—の開発
はじめに
こんにちは。
データグリッドのAIエンジニアの李です。
データグリッドでは、経済産業省およびNEDOが推進する日本の生成AIの開発力強化に向けたプロジェクト「GENIAC」第2期の支援のもと、Vision系基盤モデルの開発に取り組みました。
当社は2017年創業の京都大学発AIスタートアップとして、GAN(敵対的生成ネットワーク)や拡散モデルといった生成AIやそれらを活用した合成データをコア技術として、創業以来製造業をはじめとする多様な産業分野へAIデータソリューションを提供してまいりました。本プロジェクトでは、これまで培ってきた生成AIに関する知見や技術を活かし、ユーザーの意図を的確に反映できる動画・画像生成基盤モデルを開発いたしました。
今回はその成果の一部である「LocalDiT」(Huggingface)画像生成モデルの研究成果および開発過程で獲得した技術的知見について共有させていただきます。LocalDiTは、Stable DiffusionやPixArtなどの公開モデルやSOTA(State-Of-The-Art)技術に匹敵する性能を維持しつつ、計算コストを大幅に削減した軽量化モデルの実現を目指して開発した独自アーキテクチャです。
LocalDiTについて
LocalDiTは、PixArt-αをベースとした0.52Bパラメータの画像生成Diffusion Transformerモデルです。従来のグローバル注意機構に代わりに局所的注意機構(Local Attention)を導入することで、計算効率の向上とパラメータ数の削減を実現しています。テキストエンコーダにはFLAN-T5-XXL(4.3B)を採用し、VAEにはSDXLのVAEを使用しています。約39M枚の画像と対応する英文キャプションを用い、256→512→1024ピクセルへと段階的に解像度を上げながら学習を行いました。
Latent Diffusion Model
タイトルにも含まれているLatent Diffusion Model (以下、LDM)は、画像生成の効率と品質を両立させるために設計されたDiffusionモデルの一種です。従来のDiffusionモデルがピクセル空間で直接動作するのに対し、LDMはVAE(Variational AutoEncoder)によって圧縮された「潜在空間(latent space)」で処理を行う点が特徴的です。また、テキストやガイド画像、その他のモダリティによる条件付けが柔軟にできる特徴もあり、近年様々な生成AI産業で応用されている技術の一つです。
以下、LDMの大まかな動作原理です。
-
VAEによる次元圧縮:まず画像を「encoder」で低次元の潜在表現(latent representation)に変換します。一般的に使われているVAEは画像の解像度を1/8にするので、例えば256x256サイズの画像の場合、32x32の潜在表現にエンコードされます。
-
潜在空間でのDiffusionプロセス:圧縮された潜在空間内でノイズ追加と除去のプロセスを学習します。学習が終わったモデルはランダムノイズから徐々に意味のある潜在表現を生成できるようになります。このDiffusionプロセスを学習するモデルの構造によって様々な技術に分類されます。
- UNetベース:Stable Diffusion (SD) v1、SD v2、SDXL等
- Transformerベース:Flux、Diffusion Transformers (DiT)、PixArt等
-
VAEによる画像復元:Diffusionプロセスによって生成された潜在表現を「decoder」を通して高次元の画像へ変換します。
軽量化のための適用技術
LocalDiTの主要な技術的特徴は、局所的注意機構(Local Attention)の導入です。従来のTransformerアーキテクチャでは、全てのトークン間の関係を計算するグローバルな自己注意機構を使用しており、これがモデルサイズと計算コストの主要な要因となっていました。
LocalDiTでは、画像を局所的なウィンドウに分割し、各ウィンドウ内でのみ注意機構を計算する方式を採用しました。具体的には、以下の技術を実装しています:
-
ウィンドウベースの局所的注意機構: 画像を小さな領域(ウィンドウ)に分割し、各領域内でのみ自己注意を計算することで、計算量を大幅に削減しています。
-
交互配置アーキテクチャ: モデル内のTransformerブロックを交互に配置し、奇数番目のブロックには局所的注意機構、偶数番目のブロックには通常の注意機構を実装することで、局所的な特徴と全体的な特徴の両方を捉える能力を維持しています。
-
効率的なパラメータ共有: モデル内の一部のパラメータを共有することで、さらにモデルサイズを削減しています。
これらの技術により、PixArt-αと比較して約20%のパラメータ削減と最大20%の推論速度向上を実現しました。
学習安定性
Diffusionモデルの学習は不安定になりがちですが、LocalDiTでは以下の工夫により学習の安定性を向上させました:
-
段階的な解像度増加: 256ピクセルの低解像度から学習を開始し、徐々に512ピクセル、そして1024ピクセルへと解像度を上げていく段階的な学習戦略を採用しました。これにより、低解像度での基本的なパターン学習から高解像度での詳細な特徴の生成へと効率的に学習を進めることができました。
-
適応的学習率スケジューリング: 学習の進行に合わせて学習率を動的に調整することで、初期段階での急速な収束と後期段階での微調整を効果的に行いました。
-
勾配クリッピング: 極端な勾配値を制限することで、学習の安定性を向上させました。
-
Layer Normalizationの導入: Attention機構にLayer Normalizationを導入することで、特に低品質データを含む学習時のロスの不安定性を抑制しました。このロスの不安定性については後に詳しく解説いたします。
学習データについて
LocalDiTの学習には、約39M枚の高品質な画像とそれに対応する英文キャプションを使用しました。データセットは以下の特徴を持っています:
-
商用利用可能なオープンデータの活用: 商用利用が可能なPixelprose、common-catalog-cc-by、pixabay、PD12Mなどのデータセットを使用しました。それぞれのデータセットの特徴と大まかな数は以下の通りです。
- Pixelprose:多様なジャンルの高品質写真、約17M枚
- common-catalog-cc-by:クリエイティブコモンズライセンスの多様な画像、約9M枚
- Pixabay:プロフェッショナル品質の写真及びイラスト、約1M枚
- PD12M:パブリックドメインの1200万枚の画像コレクション、約12M枚
-
段階的データ戦略: 学習フェーズごとに異なるデータセットを使用しました。
- 256ピクセル学習: Pixelprose、common-catalog-cc-by
- 512ピクセル学習: pixabay、PD12M
- 1024ピクセル学習: pixabayのみ
この戦略を採用した理由は、高解像度学習になるほど画像品質の重要性が増すためです。pixabayは特に高品質な画像が多く含まれていることが知られています。また、PD12Mには古い時代の画像が多く含まれているため、1024ピクセル学習ではその影響を避けるためpixabayのみを使用しました。それぞれのフェーズの学習ステップは基本的に検証用の画像のクオリティを定期的にチェックしてこれ以上クオリティの改善が期待できなくなった時点でストップをかけ、次のフェーズへ移りました。各学習フェーズで使用したパラメータ等を以下のテーブルにまとめました。
Training phase Image resolution Batch size Num GPUs Steps Phase 1 256x256 160 H200 160 220k Phase 2 512x512 64 H200 160 80k Phase 3 1024x1024 8 H200 160 20k -
データ前処理: 元画像の中央を基準に1:1比率でクロップし、各解像度にリサイズしました。この処理によりアスペクト比の一貫性を保ちつつ、重要な被写体を中心に維持することができました。
-
効率的なデータパイプライン: WebDatasetフォーマットを採用し、NVIDIA DALIを使用した高速データローディングパイプラインを構築しました。これにより、I/Oボトルネックを最小限に抑え、GPUの稼働率を向上させることができました。
ただし、全解像度のデータを個別に保存する方式を採用したため、ディスク容量の不足や、データ移行に多大な時間を要するなどの課題も生じました。今後の改善点として、最高解像度(1Kや2K)のデータセットのみを保存し、学習時に動的にリサイズする戦略も検討価値があると考えています。
学習過程における特筆すべき観察点
学習過程で以下のような興味深い現象が観察されました:
-
データ品質の影響: common-catalog-cc-byデータセットを含めた学習では、一定ステップごとにロスが急激に変動する現象(上記の図では、4800 step付近)が見られました。このデータセットには、Webページのスクリーンショット、ポスター画像、文章のみのスクリーンショットなど、実写画像生成という目的にとっては「ノイズ」となるデータが多く含まれていました。こうしたデータがモデルの画像-テキスト関連性学習を妨げていたと考えられます。
この問題に対してはAttention機構にLayer Normalizationを導入することで改善を図りましたが、根本的な解決策としては高品質データのみを選別することの重要性が再確認されました。Layer Normalizationを導入する前後のロスの変動を図示した上記のグラフを見ると、両方とも5000ステップ当たりでロスが急激に下がっていますが、Layer Normalizationの実装がないモデルの場合、そのあとロスが発散しており、Layer Normalizationを入れたモデルの場合、以前のロスに回復することがわかります。実際検証用で出力される画像も、Layer Normalizationが無いときはノイズしか出力しなくなる反面、Layer Normalizationがある場合は以前の画像を出力できるようになります。この、(1)一旦ロスが下がる現象と、(2)その後ロスが発散するか安定するかの現象は、Layer Normalizationの有無の効果と学習画像のクオリティの問題と合わせてより研究が必要な点です。
-
高解像度学習時のロス特性: 256ピクセルと512ピクセルの学習では比較的安定していたロスの分散が、1024ピクセル学習では大きく増大する傾向が見られました。興味深いことに、ロスの移動平均値自体は256ピクセル学習で収束して以降、解像度を上げても大きく変動しませんでした。上記の図に現在公開したモデルの全学習過程のロスの履歴を図示しました。256x256不学習フェーズと512x512学習フェーズに比べて明らかにロスのばらつきが激しくなったことが分かります。
この現象の原因としては、以下の要因が考えられます:
- 高解像度になるほど生成すべき細部情報が増え、タスクの複雑性が上がる
- クロップとリサイズの処理による情報損失が高解像度では顕著になる
- バッチサイズの減少(高解像度ほどメモリ制約により小さくなる)による勾配推定の不安定化
- 使用データセットの変更(pixabayのみ)による分布シフト
学習効率について
学習効率を向上させるために、以下の技術を適用しました:
-
混合精度学習: 計算効率を高めるために16ビット(BFloat16)の混合精度学習を採用しました。
-
分散学習の最適化: 複数のGPUにわたる効率的な分散学習を実現するために、DeepSpeedフレームワークを活用しました。
-
WebDatasetとDALIの活用: 大規模データセットの効率的な処理のため、WebDatasetフォーマットでデータを保存し、NVIDIA DALIを用いた高速データローディングを実装しました。これにより、I/Oボトルネックを解消し、GPU利用効率を最大化しました。
学習結果まとめ
LocalDiTは、パラメータ数の大幅な削減と計算効率の向上にもかかわらず、PixArt-αに匹敵する画像生成品質を達成しました。特に以下の点で優れた結果を示しています:
-
画像品質: テキストプロンプトに基づく高品質な画像生成を実現しています。
-
計算効率: 推論速度が大幅に向上し、メモリ使用量も削減されています。
-
多様性: 様々なスタイル、構図、シーンの生成に対応可能です。
-
スケーラビリティ: 256から1024ピクセルの様々な解像度での画像生成に対応しています。
モデルサイズの比較
Method | #Params | #Images | GPU days |
---|---|---|---|
DALL·E | 12.0B | 250M | - |
DALL·E 2 | 6.5B | 650M | - |
Stable Diffusion v1.5 | 0.9B | 2000M | 6,250 A100 |
Stable Diffusion v2.1 | 0.8B | 2000M | ~ 6,000 A100 |
Stable Diffusion XL | 10.1B | - | ~ 6,000 A100 |
PixArt-α | 0.6B | 25M | 753 A100 |
LocalDiT | 0.5B | 39M | 448 H200 |
* NOTE: Stable Diffusion XLのパラメータ数は、SDXL baseの3.5BとSDXL refinerの6.6Bの合計です。
出力画像サンプル
学習したLocalDiTを使って1024x1024の画像を生成したサンプルを共有します。下のテーブルは4種類のプロンプトに対して生成された画像の例です。また、生成された画像に対するプロンプトの応答を太文字にして表示しておきました。今回学習したモデルは実際写真で撮ったようなリアルなテクスチャの画像を生成するのが特徴で、画像内のシーンや雰囲気も入力したプロンプトに合わせて生成されることがわかります。
Sample 1 | Sample 2 | Sample 3 | Sample 4 | |
---|---|---|---|---|
Image | ![]() |
![]() |
![]() |
![]() |
Prompt | Cinematic photograph of an ancient castle in misty jungle treetops. | Split-second capture of lightning bolt striking rocky desert mesa under stormy sky. | Ultra realistic macro photograph of a vintage green glass bottle covered in cold condensation droplets, backlit to reveal tiny air bubbles in the glass. | Calm coastal scene of a sea otter floating on its back, cracking a shell with a rock while kelp drifts around. |
T2I-CompBench評価結果
今回の開発のベンチマークターゲットであったPixArtの論部にも使用されたT2I-CompBenchによる定量的評価を行いました。この評価によりLocalDiTの複合的なテキスト指示への対応能力、属性バインディング、空間関係の理解度などを客観的に測定します。評価対象は広く使われているStable Diffusion XL(SDXL)、OpenAIのDalle-2、PixArt-αと我々のLocalDiTです。
Model | Color ↑ | Shape ↑ | Texture ↑ | Spatial ↑ | Non-Spatial ↑ | Complex ↑ |
---|---|---|---|---|---|---|
SDXL | 0.6369 | 0.5408 | 0.5637 | 0.2032 | 0.3110 | 0.4091 |
Dalle-2 | 0.5750 | 0.5464 | 0.6374 | 0.1283 | 0.3043 | 0.3696 |
PixArt-α | 0.6886 | 0.5582 | 0.7044 | 0.2082 | 0.3179 | 0.4117 |
LocalDiT (ours) | 0.6567 | 0.4714 | 0.6710 | 0.1360 | 0.2897 | 0.3491 |
評価結果から、LocalDiTはテクスチャ(Texture)と色彩(Color)の表現において比較的高いスコアを示しており、特にテクスチャ表現においてはDalle-2を上回る性能を達成しています。一方で、形状(Shape)や空間関係(Spatial)、複合的な指示(Complex)の処理においては改善の余地があることが分かります。
これらの結果は、局所的注意機構を採用したことによる特性と考えられます。局所的な処理はテクスチャのような局所的特徴の捉え方が優れている一方、オブジェクト全体の形状や複数オブジェクト間の関係理解には課題があることを示唆しています。
スクラッチから学習を開始したモデルとしては良好な結果であり、特にテクスチャ表現の強みを活かした用途(マテリアル生成、テクスチャリング等)において優位性を発揮できる可能性があります。今後の改良により、形状理解や空間関係の把握能力を向上させることで、より汎用的な画像生成モデルへと進化させることが課題です。
まとめ
LocalDiTは、最先端の画像生成技術を軽量化・効率化することで、より広い用途での活用を可能にするモデルです。局所的注意機構の導入により、計算コストを大幅に削減しながらも高品質な画像生成能力を維持することに成功しました。
学習過程から得られた知見としては、データセット品質の重要性、特に高解像度学習における厳選されたデータの影響が挙げられます。また、効率的なデータ処理パイプラインの構築が、大規模モデル学習の成功に不可欠であることも再確認されました。
T2I-CompBenchの評価結果からは、テクスチャ表現に強みを持つ一方で、オブジェクト生成や空間関係の理解に課題があることが分かりました。これらの知見は今後のモデル改良における重要な指針となります。
今後の研究方向としては、さらなるモデルの軽量化、多言語対応の拡充、形状理解能力の向上、より高解像度の画像生成能力の向上などが考えられます。また、実世界のアプリケーションへの統合を進め、クリエイティブ産業やデザイン、コンテンツ制作など様々な分野での活用を目指していきます。
Discussion