拡散言語モデルの仕組み [論文より]
はじめに
画像認識/生成に興味があり、最近は拡散モデルを学んでいました。
勉強していたときに、「画像の生成方法はわかった。けど、言語も拡散モデルで生成してるのあったよな...」と思い、調べてみました。
現在、呼び名はあまり統一されてなく、Large Language Diffusion Models (大規模言語拡散モデル)、Diffusion Language Models (拡散言語モデル)、diffusion Large Language Model (dLLM: 拡散型大規模言語モデル)などと呼ばれています。
2025年2月末に、Inception LabsのMercury Corderが話題となりましたね。既存のLLMとは異なるアプローチを取り、秒間1000トークンの生成で爆速生成を実現していました。また、LLaDA(Large Language Diffusion with mAsking)という拡散言語モデルも発表されています。
今回は、LLaDAを紹介している論文である、『Large Language Diffusion Models』を読んで調べたことをまとめてみます。
※ 素人による調査なので間違っている場合があります。
Large Language Diffusion Models
1章: Introduction (研究のモチベーション)
問: LLMが示す知能を実現するには、自己回帰パラダイムを使うしかないのか?
[1]では、この答えは「単純にYesではない」といっていました。
現在主流のアプローチはTransformerなどの自己回帰モデリング(ARM)ですね。ARMとは、いわゆる次トークンを予測→また予測→またまた予測...のようにしていくものです。
これは非常に効果的であり、現在のLLMの基盤となっていますね。
でも、これはあくまで、やってみた結果良かったよね、であって、ARMがLLMの本質を捉えているとは限りません。
LLMの本質的な性質を支えているのは、生成モデリングの原理(最大尤度推定)であると考えられています。
さらに、指示追従やインコンテキスト学習の能力も、構造的な一貫性を持つ言語タスクに対する適切な条件付き生成モデル全般の本質的性質であり、ARMだけの特権ではありません。
このことから、この論文ではLLaDA (Large Language Diffusion with mAsking)を提案しています。これは、LLMsが示す能力がARM以外の生成モデリング原理からも現れるかどうかを検証し、先に述べた根本的な問いに答えようとするものです。
LLaDAはMasked Diffusion Model(MDM)という離散的なランダムマスキングプロセスを組み込み、その逆過程を近似するマスク予測器を訓練します。この設計により、LLaDAは双方向依存性を持つモデル分布を構築し、対数尤度 (のlower bound)を最適化することで、既存LLMsとは異なるアプローチを提供します。
LLaDAでは、データ準備→事前学習→教師ありファインチューニング (SFT)→評価という標準的なパイプランを採用し、8B規模の拡散言語モデルへとスケールさせました。事前学習では2.3兆トークンを使用され、その後450万ペアのデータでSFTされています。
LLaDAは言語理解、数学、コード生成、中国語などの様々なタスクに置いて、
スケーラビリティ、インコンテキスト学習、指示追従性において優れた性能を見せています。また、Reversal Curse (逆転の呪い)という、「AはBである」というデータから、その逆「BはAである」というパターンを自動的に推測できないという現象をも克服したと述べています。
2章: Approach (仕組み)
2.1: 定式化
forward processとreverse processというものを通じて、モデル分布
forward processでは、文章
reverse processでは、
ここではマスク予測器
ここで、
訓練後はマスク予測器
上式の意味
生成モデリングの原理によると
ここで、
の左辺は生成モデリングの原理により、
2.2 Pre-training
LLaDAはマスク予測器としてTransformerを用いていますが、Causal mask (Causal Attention, Masked Attention)を利用してないそうです。これは、既存LLMは次トークン予測なので未来のトークンを見ないように採用されてましたが、LLaDAは予測時に入力全体を参照し、穴埋めをする方式だからです。
LLaDA 8BはLLaMA3 8Bと比較しやすいように、ハイパーパラメータは揃えつつも、必要最低限の変更は入れています。例えばLLaDAはKVキャッシュ(Multi Head Attentionを省メモリ化するための工夫)との互換性がないため、Grouped Query Attentionではなく標準のMulti Head Attentionを採用しています。その影響でAttention層のパラメータが増えるので、FFN次元を減らしてモデルサイズを8Bに保っています。トークナイザーも異なるそうです。
LLaDAモデルは2.3兆トークンからなるデータセットで事前学習されていて、データプロトコルも既存のLLMとほぼ変わらないそうです。特別な手法は使わず、データはオンラインコーパスから取得し、低品質な内容は手動設計ルールやLLMベースの手法で除去されています。
一般テキストに加え、高品質なコード、数学、多言語データも含まれています。データソースやドメインの混合比は小規模なARMでガイドされています。事前学習は4096トークンの固定系列長で行い、総計算コストは0.13万H800 GPU時間らしく、同規模・同データセットのARMと同程度だそうです。
手法的にも簡単で、トレーニングデータ
また、4096トークンの固定系列長だけではなく、事前学習データの1%はランダム長(
学習率のスケジューラーには、Warmup-Stable-Decayというのを用いていて、これはその名の通り、最初は徐々に学習率を上げていき(Warmup)、中間は固定し(Stable)、後半は徐々に減衰させていく(Decay)方法です。
2.3. Supervised Fine-Tuning (SFT)
既存のLLMと同じようにLLaDAでもInstruction Tuningを行います。データセットは、
今までは、いうなれば尤度
上図のように、
これは事前学習とほとんど同じです。
LLaDAはこれを、450万ペアのデータセットで行ったそうです。特に特別な手法は導入されていません。
データ長はばらばらになるので、パディングとして[EOS]トークンを付与し、全データの長さを揃えています。数式の
[EOS]は訓練時は通常のトークン特別せずに扱われますが、サンプリング時(推論時)に除去することで、LLaDAが応答長を自動で制御できるようになっています。
学習率のスケジューラーは、Pre-trainingと同じです。3エポックでの学習だそうです。
2.4 Inference (推論)
学習も終わり、いよいよ文章を生成する段階です!
まずはサンプリングです。上図のようにプロンプト
デフォルトでは等間隔の離散タイムステップを用います。また、生成長もハイパーパラメータではありますが、最終結果はこの長さパラメータにほとんど依存しないという実験結果が出ています。
ここでは中間ステップ
ここまでテキストで説明してきましたが、エンジニアの方は下記がわかりやすいのではないでしょうか。
原則、forward processの逆プロセスなので、reverse processのRemaskは完全にランダムであるべきですが、いくつかのRemasking戦略も検討されているそうです。具体的には下記の戦略があります。
- 予測信頼度が最も低いトークンを再マスクする「low-confidence remasking」
- テキスト系列を複数ブロックに分割して左から右に生成する「semi-autoregressive remasking」
3.Experiments
下流タスクにおけるスケーラビリティを、LLaDAと自前で構築したARMベースラインとで比較する。特に1B規模では、LLaDAとARMが同じアーキテクチャ、データ、その他すべての設定を共有しています。8Bタスクでは制約上、若干異なるサイズ同士の比較となります。
スケーラビリティについては下図の通りで、全体的な傾向はARMと同水準になっています。PIQAのようなタスク(言語モデルが物理的な世界についてどの程度学習しているかを調査する質問が含まれている。選択肢から最も適切な答えを選ぶ。物理的な世界とは...?)でも、パフォーマンスが遅れをとるものの、スケールが大きくなるにつれてARMとの差が縮まっています。
ベンチマークについても、LLaMA3 8Bには劣っていますが、大きくは劣ってない状況が見て取れます。数学や中国語タスクで優位性があるようです([1]の表2)。
また、Reversal Curse (逆転の呪い)を効果的に克服し、順方向・逆方向の両タスクで一貫したゼロショット性能を示しました。一方、Qwen 2.5やGPT-4oはいずれも両者の間に大きなギャップを示してます([1]の表3)。
指示追従に関してです。LLaDAが非自己回帰的な方法で一貫性・流暢性・長文生成能力を持つことが示されています。マルチターンダイアログ能力に関しても、会話履歴を的確に保持し、複数言語にわたって文脈に適した応答を生成できています。
おわりに
今回は拡散モデルに興味を持った流れで、拡散言語モデルの論文である"Large Language Diffusion Models"を読んでみました!まだまだGPTなどには及ばなそうですが、研究が進めば高速で精度の良いモデルが爆誕するかもしれませんね!
参考文献
[1] Shen Nie et al. "Large Language Diffusion Models" arXiv:2502.09992v2 2025
画像に関する拡散モデルの解説書
[2] 斎藤 康毅. ゼロから作るDeep Learning ❺ 生成モデル編. O'Reilly Japan. 2024
Discussion