拡散言語モデルの仕組み [論文より]

に公開

はじめに

画像認識/生成に興味があり、最近は拡散モデルを学んでいました。
勉強していたときに、「画像の生成方法はわかった。けど、言語も拡散モデルで生成してるのあったよな...」と思い、調べてみました。

現在、呼び名はあまり統一されてなく、Large Language Diffusion Models (大規模言語拡散モデル)、Diffusion Language Models (拡散言語モデル)、diffusion Large Language Model (dLLM: 拡散型大規模言語モデル)などと呼ばれています。
2025年2月末に、Inception LabsのMercury Corderが話題となりましたね。既存のLLMとは異なるアプローチを取り、秒間1000トークンの生成で爆速生成を実現していました。また、LLaDA(Large Language Diffusion with mAsking)という拡散言語モデルも発表されています。

今回は、LLaDAを紹介している論文である、『Large Language Diffusion Models』を読んで調べたことをまとめてみます。
※ 素人による調査なので間違っている場合があります。

Large Language Diffusion Models

1章: Introduction (研究のモチベーション)

問: LLMが示す知能を実現するには、自己回帰パラダイムを使うしかないのか?
[1]では、この答えは「単純にYesではない」といっていました。

現在主流のアプローチはTransformerなどの自己回帰モデリング(ARM)ですね。ARMとは、いわゆる次トークンを予測→また予測→またまた予測...のようにしていくものです。
これは非常に効果的であり、現在のLLMの基盤となっていますね。

でも、これはあくまで、やってみた結果良かったよね、であって、ARMがLLMの本質を捉えているとは限りません。
LLMの本質的な性質を支えているのは、生成モデリングの原理(最大尤度推定)であると考えられています。
さらに、指示追従やインコンテキスト学習の能力も、構造的な一貫性を持つ言語タスクに対する適切な条件付き生成モデル全般の本質的性質であり、ARMだけの特権ではありません。

このことから、この論文ではLLaDA (Large Language Diffusion with mAsking)を提案しています。これは、LLMsが示す能力がARM以外の生成モデリング原理からも現れるかどうかを検証し、先に述べた根本的な問いに答えようとするものです。
LLaDAはMasked Diffusion Model(MDM)という離散的なランダムマスキングプロセスを組み込み、その逆過程を近似するマスク予測器を訓練します。この設計により、LLaDAは双方向依存性を持つモデル分布を構築し、対数尤度 (のlower bound)を最適化することで、既存LLMsとは異なるアプローチを提供します。

LLaDAでは、データ準備→事前学習→教師ありファインチューニング (SFT)→評価という標準的なパイプランを採用し、8B規模の拡散言語モデルへとスケールさせました。事前学習では2.3兆トークンを使用され、その後450万ペアのデータでSFTされています。
LLaDAは言語理解、数学、コード生成、中国語などの様々なタスクに置いて、

スケーラビリティ、インコンテキスト学習、指示追従性において優れた性能を見せています。また、Reversal Curse (逆転の呪い)という、「AはBである」というデータから、その逆「BはAである」というパターンを自動的に推測できないという現象をも克服したと述べています。

2章: Approach (仕組み)

2.1: 定式化

forward processreverse processというものを通じて、モデル分布p_\theta (x_0)を取得します。
forward processでは、文章x_0内の各トークンを徐々にマスクし、t=1で完全にマスクされます (t\in(0, 1))。このとき、x_tは部分的にマスクされていて、各トークンは確率tでマスク、1-tで非マスク状態となっています。
reverse processでは、t=1 \to 0へと動かしながら、マスクされたトークンを逐次予測してデータ分布を復元します。

ここではマスク予測器p_\theta(\cdot|x_t)が核となります。これはx_tをインプットとして受け取ったときに、マスクされているトークンMを同時に予測するものとなります。

\mathcal{L}(\theta) \triangleq -\mathbb{E}_{t, x_0, x_t}\left[\frac{1}{t} \sum_{i=1}^L \bm{1}[x_t^i=M] \log p_\theta(x_0^i|x_t)\right]

ここで、x_0は訓練データからサンプルされた文章であり、tは[0, 1]の一様分布からサンプリングされたものであり、x_tはforward processによりサンプリングされた文章です。また、\bm{1}[\cdot]は、x_t^i$ (x_ti文字目)がマスクされていた場合のみ1、それ以外は0となり、マスクされたトークンについてのみクロスエントロピー損失に含むようにする指示関数です。

訓練後はマスク予測器p_\theta(\cdot|x_t)によりreverse processを実行し、t=0のときの周辺分布としてp_\theta(x_0)を定義します。\mathcal{L}(\theta)の損失がモデル分布の負対数尤度の上界であることが証明されている点です。

-\mathbb{E}_{p_{\text{data}}(x_0)}[\log p_\theta(x_0)] \leq \mathcal{L}(\theta)
上式の意味

生成モデリングの原理によると

\max_{\theta} \mathbb{E}_{p_{\text{data}}(x)} \log p_{\theta}(x) \iff \min_{\theta} \mathrm{KL}\left( p_{\text{data}}(x) \,\|\, p_{\theta}(x) \right)

ここで、p_{\text{data}}(x)は真のデータに対する分布です。

-\mathbb{E}_{p_{\text{data}}(x_0)}[\log p_\theta(x_0)] \leq \mathcal{L}(\theta)

の左辺は生成モデリングの原理により、p_{\text{data}}(x) = p_{\theta}(x)となるときに最小となるので、これを目指しています。しかし、p_{\text{data}}(x)は神のみぞ知るモデルなので、実際に計算はできません。そのため左辺を最小化する代わりに、\mathcal{L}(\theta)を最適化して上界を求め、p_{\text{data}}(x)に近くなるようなp_{\theta}(x)を求めるためのパラメータ\thetaを求めることに繋がります。

2.2 Pre-training

LLaDAはマスク予測器としてTransformerを用いていますが、Causal mask (Causal Attention, Masked Attention)を利用してないそうです。これは、既存LLMは次トークン予測なので未来のトークンを見ないように採用されてましたが、LLaDAは予測時に入力全体を参照し、穴埋めをする方式だからです。

LLaDA 8BはLLaMA3 8Bと比較しやすいように、ハイパーパラメータは揃えつつも、必要最低限の変更は入れています。例えばLLaDAはKVキャッシュ(Multi Head Attentionを省メモリ化するための工夫)との互換性がないため、Grouped Query Attentionではなく標準のMulti Head Attentionを採用しています。その影響でAttention層のパラメータが増えるので、FFN次元を減らしてモデルサイズを8Bに保っています。トークナイザーも異なるそうです。

LLaDAモデルは2.3兆トークンからなるデータセットで事前学習されていて、データプロトコルも既存のLLMとほぼ変わらないそうです。特別な手法は使わず、データはオンラインコーパスから取得し、低品質な内容は手動設計ルールやLLMベースの手法で除去されています。

一般テキストに加え、高品質なコード、数学、多言語データも含まれています。データソースやドメインの混合比は小規模なARMでガイドされています。事前学習は4096トークンの固定系列長で行い、総計算コストは0.13万H800 GPU時間らしく、同規模・同データセットのARMと同程度だそうです。

手法的にも簡単で、トレーニングデータx_0に対してt\in[0, 1]をランダムサンプリングし、各トークンを確率tで独立にマスクしてx_tを得ます。モンテカルロ法を用いて\mathcal{L}(\theta)を推定し、確率的勾配降下法で訓練します。
また、4096トークンの固定系列長だけではなく、事前学習データの1%はランダム長([1, 4096])になるようにしているそうです。

学習率のスケジューラーには、Warmup-Stable-Decayというのを用いていて、これはその名の通り、最初は徐々に学習率を上げていき(Warmup)、中間は固定し(Stable)、後半は徐々に減衰させていく(Decay)方法です。

2.3. Supervised Fine-Tuning (SFT)

既存のLLMと同じようにLLaDAでもInstruction Tuningを行います。データセットは、(p_0, r_0)です。ここで、p_0はプロンプト、r_0はレスポンスを示します。
今までは、いうなれば尤度p_\theta(x_0)を最大化するような学習をしていましたが、今回は条件付き確率p_\theta(r_0|p_0)をモデリングします。


上図のように、tにかかわらずプロンプトは固定し、レスポンスのみを独立にマスクします。その後、プロンプトとマスク済みレスポンスr_tを事前学習済みのマスク予測器に入力し、SFT用の損失を計算します。

-\mathbb{E}_{t, p_0, r_0, r_t}\left[\frac{1}{t} \sum_{i=1}^{L'} 1[r_t^,i=M] \log p_\theta(r_0^i|p_0, r_t)\right]

これは事前学習とほとんど同じです。p_0r_0の連結はx_0p_0r_tの連結はx_tとみなせますね、ただし、x_tとの違いは、マスクされるトークンがすべてレスポンス部分になる、という点のみです。

LLaDAはこれを、450万ペアのデータセットで行ったそうです。特に特別な手法は導入されていません。
データ長はばらばらになるので、パディングとして[EOS]トークンを付与し、全データの長さを揃えています。数式のL'は[EOS]により可変になったものを表しています。
[EOS]は訓練時は通常のトークン特別せずに扱われますが、サンプリング時(推論時)に除去することで、LLaDAが応答長を自動で制御できるようになっています。

学習率のスケジューラーは、Pre-trainingと同じです。3エポックでの学習だそうです。

2.4 Inference (推論)

学習も終わり、いよいよ文章を生成する段階です!


まずはサンプリングです。上図のようにプロンプトp_0が与えられた場合、レスポンスが完全にマスクされた状態(t=1)から始まります。t=1の状態からreverse processを行います。サンプリングステップ(何ステップで生成完了とするか)はハイパーパラメータであり、レイテンシーと生成テキストの品質のトレードオフになります。
デフォルトでは等間隔の離散タイムステップを用います。また、生成長もハイパーパラメータではありますが、最終結果はこの長さパラメータにほとんど依存しないという実験結果が出ています。

ここでは中間ステップt\in (0,1]からs\in[0, t)への遷移を考えましょう。sは、reverse processにおけるtの次のタイムステップ(s<t)です。
p_0r_tをマスク予測器に入力し、すべてのマスクされたトークンを同時に予測します。その後、予測されたトークンのうちs/tの割合を再度マスク(Remask)してr_sとし、reverse processの遷移がforward processの逆をたどるようにしています。徐々にテキストが埋まっていく感じですね。

ここまでテキストで説明してきましたが、エンジニアの方は下記がわかりやすいのではないでしょうか。

原則、forward processの逆プロセスなので、reverse processのRemaskは完全にランダムであるべきですが、いくつかのRemasking戦略も検討されているそうです。具体的には下記の戦略があります。

  • 予測信頼度が最も低いトークンを再マスクする「low-confidence remasking」
  • テキスト系列を複数ブロックに分割して左から右に生成する「semi-autoregressive remasking」

3.Experiments

下流タスクにおけるスケーラビリティを、LLaDAと自前で構築したARMベースラインとで比較する。特に1B規模では、LLaDAとARMが同じアーキテクチャ、データ、その他すべての設定を共有しています。8Bタスクでは制約上、若干異なるサイズ同士の比較となります。

スケーラビリティについては下図の通りで、全体的な傾向はARMと同水準になっています。PIQAのようなタスク(言語モデルが物理的な世界についてどの程度学習しているかを調査する質問が含まれている。選択肢から最も適切な答えを選ぶ。物理的な世界とは...?)でも、パフォーマンスが遅れをとるものの、スケールが大きくなるにつれてARMとの差が縮まっています。

ベンチマークについても、LLaMA3 8Bには劣っていますが、大きくは劣ってない状況が見て取れます。数学や中国語タスクで優位性があるようです([1]の表2)。

また、Reversal Curse (逆転の呪い)を効果的に克服し、順方向・逆方向の両タスクで一貫したゼロショット性能を示しました。一方、Qwen 2.5やGPT-4oはいずれも両者の間に大きなギャップを示してます([1]の表3)。

指示追従に関してです。LLaDAが非自己回帰的な方法で一貫性・流暢性・長文生成能力を持つことが示されています。マルチターンダイアログ能力に関しても、会話履歴を的確に保持し、複数言語にわたって文脈に適した応答を生成できています。

おわりに

今回は拡散モデルに興味を持った流れで、拡散言語モデルの論文である"Large Language Diffusion Models"を読んでみました!まだまだGPTなどには及ばなそうですが、研究が進めば高速で精度の良いモデルが爆誕するかもしれませんね!

参考文献

[1] Shen Nie et al. "Large Language Diffusion Models" arXiv:2502.09992v2 2025
https://arxiv.org/abs/2502.09992v2

画像に関する拡散モデルの解説書
[2] 斎藤 康毅. ゼロから作るDeep Learning ❺ 生成モデル編. O'Reilly Japan. 2024
https://www.oreilly.co.jp/books/9784814400591/

Discussion