📖

DPOとGRPO:PPO以降の手法1:強化学習コース(12/N)

に公開

はじめに

前回の記事では、PPO(Proximal Policy Optimization)について詳しく学びました。TRPOの理論的基盤から始まり、PPOがクリッピング機構によってどのように実装の簡素化を実現したかを理解しました。また、LLMのRLHFにおけるPPOの重要な役割についても確認しました。

本記事では、PPOの限界を克服するために開発された最新の手法であるDPO(Direct Preference Optimization)GRPO(Group Relative Policy Optimization) について解説します。これらの手法は、従来のアプローチとは根本的に異なる発想で、より効率的で実用的な学習を実現しています。

これまでのシリーズで学んだ内容は以下の通りです。

  • 第1回 - 強化学習の全体像と問題設定
  • 第2回 - 環境側の定式化(マルコフ決定過程)
  • 第3回 - エージェント側の概念とプランニングアルゴリズム(価値反復法)
  • 第4回 - モデルフリー学習とモンテカルロ法
  • 第5回 - 時間差分学習とTD誤差
  • 第6回 - 多段階TD学習とEligibility Trace
  • 第7回 - 関数近似による価値関数学習
  • 第8回 - 方策勾配法の基礎理論
  • 第9回 - Actor-Critic法
  • 第10回 - LLMと強化学習の融合
  • 第11回 - PPO(Proximal Policy Optimization)
  • 第12回(本記事) - DPOとGRPO:PPOの限界を超える最新手法

PPOの課題と限界

従来のRLHFパイプラインの複雑性

これまで学んだように、従来のRLHFは以下の3段階で構成されます。

  1. 教師あり微調整(SFT) - 高品質なデータでモデルを初期化
  2. 報酬モデルの学習 - 人間の好みを予測するモデルを構築
  3. 強化学習による最適化 - PPOで方策を改善

この3段階のプロセスには以下の問題があります。

報酬モデルの不安定性

報酬モデルr_\phi(x, y)は、人間の好みデータから学習されますが、いくつかの重要な課題があります。

まず分布シフトの問題があります。PPOによる学習が進むと生成される文章の分布が変化し、報酬モデルの予測精度が低下します。これは、報酬モデルが学習時のデータ分布と異なる分布の入力に対して適切な評価ができなくなるためです。

また、過学習の問題も深刻です。報酬モデルが学習データに過度に適合し、新しいタイプの出力に対して不適切な評価をする可能性があります。特に学習データが限られている場合、この問題は顕著になります。

さらに報酬ハッキングの問題もあります。モデルが報酬モデルの脆弱性を悪用して、見かけ上高い報酬を得ながら実際には低品質な出力を生成することがあります。これは報酬モデルが人間の真の意図を完全に捉えきれていないために起こります。

PPOの計算コスト

PPOによる学習は計算コストが高いという問題もあります。

まず、方策ネットワーク\pi_\thetaと価値ネットワークV_\psiの両方を維持する必要があり、メモリ使用量が倍増します。特に大規模言語モデルの場合、この負担は無視できません。

さらに、各更新ステップで複数回の勾配更新をするため、計算時間が長くなります。PPOは同じデータに対して複数のエポックで学習しますが、これは計算効率の観点から見ると非効率的です。

また、バッチサイズとエポック数の調整が複雑で、ハイパーパラメータに敏感という問題もあります。最適な設定を見つけるために多くの実験が必要となり、開発コストが増大します。

DPO(Direct Preference Optimization)

DPOの基本アイデア

DPOは、2023年にStanford大学のRafailov et al.によって提案された革新的な手法です。その核心的なアイデアは「報酬モデルを介さずに、人間の好みデータから直接方策を最適化する」ことです。

理論的基盤:逆強化学習の視点

Bradley-Terryモデルの復習

人間の好みは、Bradley-Terryモデルで表現されます。

P(y_1 \succ y_2 | x) = \frac{\exp(r(x, y_1))}{\exp(r(x, y_1)) + \exp(r(x, y_2))} = \sigma(r(x, y_1) - r(x, y_2))

ここで、y_1 \succ y_2は「y_1y_2より好まれる」ことを表します。

最適方策と報酬の関係

KL制約下での強化学習の最適解は、解析的に求めることができます。目的関数が以下の場合を考えます。

\max_\pi \mathbb{E}_{x \sim D, y \sim \pi(y|x)}[r(x, y)] - \beta \mathcal{D}_{KL}[\pi(y|x) \| \pi_{ref}(y|x)]

この最適化問題の解は以下の形で表されます。

\pi^*(y|x) = \frac{1}{Z(x)} \pi_{ref}(y|x) \exp\left(\frac{1}{\beta} r(x, y)\right)

ここで、Z(x)は分配関数(partition function)です。

Z(x) = \sum_y \pi_{ref}(y|x) \exp\left(\frac{1}{\beta} r(x, y)\right)

報酬関数の逆算

上記の関係式を報酬について解くと以下のようになります。

r(x, y) = \beta \log \frac{\pi^*(y|x)}{\pi_{ref}(y|x)} + \beta \log Z(x)

Z(x)xのみに依存しyには依存しないため、Bradley-Terryモデルにおける確率比には影響しません。

P(y_1 \succ y_2 | x) = \sigma(\beta \log \frac{\pi^*(y_1|x)}{\pi_{ref}(y_1|x)} - \beta \log \frac{\pi^*(y_2|x)}{\pi_{ref}(y_2|x)})

DPOの目的関数

負の対数尤度損失

DPOは、人間の好みデータ\mathcal{D} = \{(x^{(i)}, y_w^{(i)}, y_l^{(i)})\}に対して、以下の負の対数尤度を最小化します。

\mathcal{L}_{DPO}(\pi_\theta) = -\mathbb{E}_{(x,y_w,y_l) \sim \mathcal{D}} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)} \right) \right]

ここで以下のように定義されます。

  • y_wは好まれる応答(winner)
  • y_lは好まれない応答(loser)
  • \betaは温度パラメータ

勾配の直感的理解

DPOの勾配を計算すると以下のようになります。

\nabla_\theta \mathcal{L}_{DPO} = -\beta \mathbb{E}_{(x,y_w,y_l) \sim \mathcal{D}} \left[ \sigma(\hat{r}_\theta(x, y_l) - \hat{r}_\theta(x, y_w)) \left( \nabla_\theta \log \pi_\theta(y_w|x) - \nabla_\theta \log \pi_\theta(y_l|x) \right) \right]

ここで、\hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)}は暗黙的な報酬関数です。

この勾配の直感的な意味は以下の通りです。\sigma(\hat{r}_\theta(x, y_l) - \hat{r}_\theta(x, y_w))は重みであり、好まれない応答の暗黙的報酬が好まれる応答より高い場合により大きな更新をします。好まれる応答y_wの確率を増加させ、好まれない応答y_lの確率を減少させる方向に学習します。

DPOの実装上の利点

実装の簡素化

DPOはPPOと比較して実装が非常に簡単です。

最も大きな利点は、報酬モデルの学習が不要なことです。これにより、従来の3段階のプロセスが2段階(SFT + DPO)に短縮され、開発サイクルが大幅に短くなります。報酬モデルの不安定性に悩まされることもありません。

また、価値ネットワークの維持が不要なため、メモリ使用量が削減されます。PPOでは方策と価値の2つのネットワークが必要でしたが、DPOでは方策ネットワークのみで十分です。

さらに、ハイパーパラメータが少なく、調整が容易という利点もあります。PPOのような複雑なハイパーパラメータ調整は不要で、温度パラメータ\betaの設定が主な調整項目となります。

計算効率

DPOの各更新で必要な計算は方策の順伝播のみで、PPOのような複数回の勾配更新は不要です。バッチ処理が簡単で、効率的な学習が可能です。

安定性

報酬モデルの不安定性による影響を受けず、分布シフトの問題も軽減されます。直接的な最適化により、報酬ハッキングの問題も部分的に解決されます。

GRPO(Group Relative Policy Optimization)

GRPOの動機と基本概念

GRPOは、DeepSeekによって2024年に提案された手法で、PPOの価値ネットワークに関する課題を解決します。その核心的なアイデアは「同じプロンプトに対する複数の応答の統計を使用してAdvantage関数を推定する」ことです。

価値ネットワークの問題点

学習の不安定性

PPOにおける価値ネットワークV_\psi(s)の学習には複数の深刻な問題があります。

まず移動ターゲット問題があります。方策が変化するとターゲット値も変化し、価値関数の学習が不安定になります。価値ネットワークは常に移動する目標を追いかけることになり、収束が困難になります。

次に過学習の問題があります。価値ネットワークが特定の状態分布に過度に適合する可能性があり、汎化性能が低下します。特に、学習の初期段階では限られた状態しか観測されないため、この問題は顕著です。

さらに計算コストの問題も無視できません。方策ネットワークと同等のサイズのネットワークを追加で維持する必要があり、メモリと計算時間の両面で大きな負担となります。

LLMにおける特有の課題

大規模言語モデルでは、さらに特有の困難があります。

第一に、系列レベル報酬の問題があります。通常は系列の最後でのみ報酬が与えられるため、各トークン生成時点での価値を推定することが極めて困難です。これは信用割当問題を深刻化させます。

第二に、高次元の状態空間の問題があります。各トークンが状態を表し、語彙サイズが数万から数十万に及ぶため、状態空間が爆発的に大きくなります。価値ネットワークがこの巨大な空間を適切に近似することは非常に困難です。

第三に、長系列の問題があります。数百から数千トークンの長い系列を扱う必要があり、計算コストとメモリ使用量が系列長に比例して増大します。また、長い系列での価値推定の精度も低下しやすくなります。

GRPOのグループベースアプローチ

基本的なアルゴリズム

与えられた質問qに対して、G個の応答をサンプリングします。

\{o_1, o_2, \ldots, o_G\} \sim \pi_{\theta_{old}}(\cdot|q)

各応答o_iは報酬r_iを受け取ります。グループベースAdvantageは以下のように計算されます。

\hat{A}_i = \frac{r_i - \mu_G}{\sigma_G}

ここで以下のように定義されます。

  • \mu_G = \frac{1}{G}\sum_{j=1}^G r_j(グループ平均)
  • \sigma_G = \sqrt{\frac{1}{G-1}\sum_{j=1}^G (r_j - \mu_G)^2}(グループ標準偏差)

理論的正当性

このアプローチには強固な理論的基盤があります。

まず、大数の法則により、グループサイズGが大きいとき、グループ平均\mu_Gは期待値\mathbb{E}_{\pi_{\theta_{old}}}[r(o)]に収束します。これにより、真の期待報酬の良い推定値が得られます。

次に、分散削減の効果があります。標準化により各グループ内でゼロ平均、単位分散を保証することで、学習の安定性が向上します。これは価値ネットワークなしでも安定した学習を可能にする重要な要素です。

さらに、このアプローチは報酬モデルの学習方法(比較判断)と自然に整合します。人間の好みデータが比較形式で与えられることが多いため、グループ内での相対的な評価は直感的で解釈しやすいという利点があります。

GRPOの目的関数

完全な目的関数

GRPOは以下の目的関数を最適化します。

J_{GRPO}(\theta) = \mathbb{E}_{q \sim \mathcal{D}} \left[ \mathbb{E}_{\{o_i\}_{i=1}^G \sim \pi_{\theta_{old}}(\cdot|q)} \left[ \frac{1}{G} \sum_{i=1}^G \mathcal{L}_i(\theta) \right] \right] - \beta \mathcal{D}_{KL}[\pi_\theta \| \pi_{ref}]

ここで、\mathcal{L}_i(\theta)は個別の応答に対する損失です。

\mathcal{L}_i(\theta) = \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \min\left( r_{i,t}(\theta) \hat{A}_i, \text{clip}(r_{i,t}(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_i \right)

トークンレベル重要度比

r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t}|q, o_{i,<t})}

実装上の詳細

KLダイバージェンス正則化

GRPOでは、KLダイバージェンスの推定に以下の不偏推定量を使用します。

\hat{\mathcal{D}}_{KL}[\pi_\theta \| \pi_{ref}] = \frac{\pi_{ref}(o_{i,t}|q, o_{i,<t})}{\pi_\theta(o_{i,t}|q, o_{i,<t})} - \log \frac{\pi_{ref}(o_{i,t}|q, o_{i,<t})}{\pi_\theta(o_{i,t}|q, o_{i,<t})} - 1

ハイパーパラメータ設定

DeepSeekMathでの設定例は以下の通りです。

  • グループサイズ G = 64
  • KL正則化係数 \beta = 0.04
  • クリッピング範囲 \epsilon = 0.2

DPOとGRPOの比較分析

理論的観点からの比較

最適化の対象

各手法は異なる最適化アプローチを採用しています。

DPOは人間の好みから直接方策を最適化するという革新的なアプローチを取ります。報酬モデルを介さずに、好みデータから直接最適な方策を導出できるため、中間表現による情報の損失を防ぐことができます。

一方、GRPOは報酬信号を使用しますが、価値ネットワークを排除するという工夫をします。グループ内の相対的な比較により、価値関数の推定を統計的に代替することで、学習の安定性と効率性を両立させています。

理論的基盤

理論的な基盤も大きく異なります。

DPOは逆強化学習の理論に基づいており、KL制約下での最適方策が解析的に求められるという強力な理論的保証があります。これにより、収束性や最適性についての理論的な分析が可能になります。

GRPOは統計的推定理論に基づいており、グループ統計による分散削減を活用します。大数の法則と中心極限定理により、グループサイズの増加に伴って推定精度が向上することが保証されています。

実装上の比較

必要なデータ

各手法で必要となるデータの種類は大きく異なります。

DPOでは人間の好みペア(y_w, y_l)が必要です。つまり、同じプロンプトに対する2つの応答を人間が比較し、どちらが良いかを判断したデータが必要となります。このような比較データの収集は手間がかかりますが、人間の判断基準を直接的に反映できるという利点があります。

一方、GRPOでは報酬信号が必要です。これは既存の報酬モデルから得ることができるため、新たなデータ収集は不要です。ただし、報酬モデルの品質がGRPOの性能に直接影響するため、高品質な報酬モデルの存在が前提となります。

計算コスト

3つの手法の計算コストには明確な違いがあります。

DPOは最も軽量な手法です。単一の方策ネットワークのみを更新すればよく、各更新ステップでの計算も単純な順伝播と逆伝播のみです。報酬モデルの計算も不要なため、全体的な計算時間が大幅に短縮されます。

GRPOは中程度の計算コストとなります。複数の応答をサンプリングする必要があるため、推論の計算コストが増加します。しかし、価値ネットワークの学習が不要なため、PPOよりは効率的です。グループサイズが大きいほど統計的に安定しますが、その分計算コストも増加するというトレードオフがあります。

PPOは最も計算コストが高い手法です。方策と価値の両ネットワークを学習する必要があり、さらに各更新で複数のエポックにわたる学習をするため、計算時間が長くなります。大規模言語モデルでは、この差は特に顕著になります。

メモリ使用量

メモリ使用量の観点でも、3つの手法には大きな差があります。

DPOは最もメモリ効率的な手法です。方策ネットワークのみを保持すればよく、追加のネットワークや複雑なバッファリングは不要です。これは特に大規模言語モデルにおいて重要な利点となります。

GRPOは中程度のメモリ使用量となります。価値ネットワークは不要ですが、グループベースのアプローチのため、複数の応答をメモリ上にバッファリングする必要があります。グループサイズが64の場合、64個の応答とその報酬を保持する必要があり、これが追加のメモリ使用につながります。

PPOは最もメモリを消費する手法です。方策ネットワークと価値ネットワークの両方を保持する必要があり、さらに学習時には過去の軌道データもバッファリングする必要があります。大規模言語モデルでは、この追加メモリが大きなボトルネックとなることがあります。

まとめ

各手法の位置づけ

PPOは理論的に確立された堅実な手法として、多くの実用システムで使用されています。複雑性は高いものの、幅広いタスクで安定した性能を提供します。TRPOの理論的基盤を実用的なレベルに落とし込んだ功績は大きく、現在でも多くのシステムで採用されています。

DPOは実装の簡素性と理論的エレガンスを両立した革新的手法です。報酬モデルを排除することで、RLHFパイプラインを大幅に簡素化しました。人間の好みデータから直接学習するというアプローチは、今後の強化学習の方向性を示唆しています。

GRPOは価値ネットワークの課題を統計的アプローチで解決した実用的手法です。特に数値的評価が可能なタスクで高い効果を示しています。グループベースの統計を活用するという発想は、他の強化学習問題にも応用可能な汎用的なアイデアです。

次回予告:CISPO

次回の最終回では、これらの発展をさらに推し進めたCISPO(Clipped Importance Sampling Policy Optimization) について解説します。CISPOは重要度サンプリングの新しいアプローチにより、さらなる効率化と安定性の向上を実現した最新手法です。

Discussion