📖

PPO(Proximal Policy Optimization):強化学習コース(11/N)

に公開

はじめに

前回の記事では、LLMと強化学習の融合について学びました。テキスト生成をマルコフ決定過程として定式化し、RLHF(Reinforcement Learning from Human Feedback)により人間の価値観を学習する仕組みを理解しました。また、KL正則化がTrust Region Policy Optimization(TRPO)の考え方に基づいていることも確認しました。

本記事では、RLHFで最も広く使用されている**PPO(Proximal Policy Optimization)**アルゴリズムについて詳しく解説します。TRPOの理論的基盤から始め、PPOがどのようにその複雑さを実用的なレベルまで簡素化したかを見ていきます。

これまでのシリーズで学んだ内容は以下の通りです。

  • 第1回 - 強化学習の全体像と問題設定
  • 第2回 - 環境側の定式化(マルコフ決定過程)
  • 第3回 - エージェント側の概念とプランニングアルゴリズム(価値反復法)
  • 第4回 - モデルフリー学習とモンテカルロ法
  • 第5回 - 時間差分学習とTD誤差
  • 第6回 - 多段階TD学習とEligibility Trace
  • 第7回 - 関数近似による価値関数学習
  • 第8回 - 方策勾配法の基礎理論
  • 第9回 - Actor-Critic法
  • 第10回 - LLMと強化学習の融合
  • 第11回(本記事)- PPO(Proximal Policy Optimization)

方策勾配法の課題

REINFORCEの復習

第8回で学んだREINFORCE法を振り返ってみましょう。方策勾配定理により、目的関数の勾配は以下のように表されます。

\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^{T} A^{\pi_\theta}(s_t, a_t) \nabla_\theta \log \pi_\theta(a_t|s_t) \right]

ここで、A^{\pi_\theta}(s_t, a_t)はAdvantage関数、\tauは軌道(trajectory)を表します。

ステップサイズの問題

REINFORCEの最大の問題は、適切な学習率の設定が困難という点です。

小さすぎる学習率:学習が非常に遅く、実用的な時間内で収束しません。

大きすぎる学習率:方策が急激に変化し、性能が破滅的に悪化する可能性があります。特に、一度悪い方策に更新されると、そこから回復することが困難になります。

この問題は、方策空間の非線形性に起因します。パラメータ空間での小さな変化が、方策空間では大きな変化を引き起こす可能性があるのです。

Trust Region Policy Optimization (TRPO)

問題設定:方策改善の定式化

TRPOを理解するには、まず「方策をどう改善するか」という問題を正確に定式化する必要があります。

方策の性能指標

方策\piの性能は、期待累積報酬によって定義されます。

J(\pi) = \mathbb{E}_{s_0 \sim \rho_0, \tau \sim \pi} \left[ \sum_{t=0}^{\infty} \gamma^t r_t \right]

ここで、\rho_0は初期状態分布です。また、\tauは方策\piに従って生成される軌道です。

価値関数の復習

第3回で学んだように、方策\piに対して以下の価値関数が定義されます。

状態価値関数は以下のように定義されます。

V^{\pi}(s) = \mathbb{E}_{\tau \sim \pi} \left[ \sum_{t=0}^{\infty} \gamma^t r_t \mid s_0 = s \right]

行動価値関数は以下のように定義されます。

Q^{\pi}(s,a) = \mathbb{E}_{\tau \sim \pi} \left[ \sum_{t=0}^{\infty} \gamma^t r_t \mid s_0 = s, a_0 = a \right]

Advantage関数(第8回で学習)は以下のように定義されます。

A^{\pi}(s,a) = Q^{\pi}(s,a) - V^{\pi}(s)

これは「状態sで行動aを取ることが、平均的な行動と比べてどれだけ良いか」を表します。

方策改善の基本定理の導出

2つの方策\pi_{old}\pi_{new}の性能差を導出します。

まず、\pi_{new}の性能を状態価値関数で表現します。

J(\pi_{new}) = \mathbb{E}_{s_0 \sim \rho_0}[V^{\pi_{new}}(s_0)]

ここで、V^{\pi_{new}}に対するベルマン方程式を使います。

V^{\pi_{new}}(s) = \mathbb{E}_{a \sim \pi_{new}(\cdot|s)}[Q^{\pi_{new}}(s,a)]

Q^{\pi_{new}}(s,a)もベルマン方程式で展開できます。

Q^{\pi_{new}}(s,a) = r(s,a) + \gamma \mathbb{E}_{s' \sim P(\cdot|s,a)}[V^{\pi_{new}}(s')]

ここで巧妙なトリックを使います。任意の関数Vに対して以下の恒等式が成り立ちます。

Q^{\pi_{new}}(s,a) = r(s,a) + \gamma \mathbb{E}_{s'}[V^{\pi_{new}}(s')] = r(s,a) + \gamma \mathbb{E}_{s'}[V(s')] + \gamma \mathbb{E}_{s'}[V^{\pi_{new}}(s') - V(s')]

V = V^{\pi_{old}}を選ぶと以下のようになります。

Q^{\pi_{new}}(s,a) = Q^{\pi_{old}}(s,a) + \gamma \mathbb{E}_{s'}[V^{\pi_{new}}(s') - V^{\pi_{old}}(s')]

これをV^{\pi_{new}}(s)の式に代入すると以下のようになります。

V^{\pi_{new}}(s) = \mathbb{E}_{a \sim \pi_{new}}[Q^{\pi_{old}}(s,a)] + \gamma \mathbb{E}_{a \sim \pi_{new}, s'}[V^{\pi_{new}}(s') - V^{\pi_{old}}(s')]

V^{\pi_{old}}(s) = \mathbb{E}_{a \sim \pi_{old}}[Q^{\pi_{old}}(s,a)]を使って整理すると以下のようになります。

V^{\pi_{new}}(s) - V^{\pi_{old}}(s) = \mathbb{E}_{a \sim \pi_{new}}[A^{\pi_{old}}(s,a)] + \gamma \mathbb{E}_{a \sim \pi_{new}, s'}[V^{\pi_{new}}(s') - V^{\pi_{old}}(s')]

この再帰的な式を展開していくと、定常状態分布\rho_{\pi_{new}}を使って以下のようになります。

J(\pi_{new}) - J(\pi_{old}) = \mathbb{E}_{s \sim \rho_{\pi_{new}}} \left[ \mathbb{E}_{a \sim \pi_{new}(\cdot|s)} [A^{\pi_{old}}(s,a)] \right]

これが方策改善の基本定理です。

実装上の困難

しかし、この式を直接最適化するには問題があります。\rho_{\pi_{new}}(新しい方策の状態分布)が未知であり、\pi_{new}が変わるたびに状態分布も変わってしまいます。

Trust Regionの基本概念

TRPOは、この問題を「局所近似」と「信頼領域」によって解決します。

局所近似による代理目的関数

方策が大きく変わらない場合、状態分布の変化を無視できると仮定します。つまり以下のような近似が可能です。

\rho_{\pi_{new}} \approx \rho_{\pi_{old}} \quad \text{(方策が近い場合)}

この近似の下で、性能改善量は以下のようになります。

J(\pi_{new}) - J(\pi_{old}) \approx \mathbb{E}_{s \sim \rho_{\pi_{old}}} \left[ \mathbb{E}_{a \sim \pi_{new}(\cdot|s)} [A^{\pi_{old}}(s,a)] \right]

さらに、重要度サンプリングを使って\pi_{old}のデータで書き換えると以下のようになります。

L(\pi_{new}) = \mathbb{E}_{s \sim \rho_{\pi_{old}}} \left[ \mathbb{E}_{a \sim \pi_{old}(\cdot|s)} \left[ \frac{\pi_{new}(a|s)}{\pi_{old}(a|s)} A^{\pi_{old}}(s,a) \right] \right]

これが代理目的関数です。J(\pi_{old})は定数なので、L(\pi_{new})を最大化することはJ(\pi_{new})を最大化することと(局所的に)等価です。

パラメータ表記にすると以下のようになります。

L(\theta) = \mathbb{E}_{s \sim \rho_{\pi_{\theta_{old}}}, a \sim \pi_{\theta_{old}}} \left[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a) \right]

代理目的関数の利点 として以下が挙げられます。一度収集したデータ(s,a)のペアを異なる\thetaで何度も再利用できます。Advantage関数A^{\pi_{\theta_{old}}}も再計算不要です。複数回の勾配更新を効率的に行えます。

KL制約

ただし、この近似は\pi_\theta\pi_{\theta_{old}}が近い場合にのみ有効です。そこで、TRPOは以下の制約付き最適化問題を解きます。

\max_\theta L(\theta)
\text{subject to: } \mathbb{E}_{s \sim \rho_{\pi_{\theta_{old}}}} [D_{KL}(\pi_{\theta_{old}}(\cdot|s) \| \pi_\theta(\cdot|s))] \leq \delta

この制約により、方策の変化を制限し、代理目的関数の近似が有効な範囲内で最適化を行います。

自然勾配法

パラメータ空間と方策空間の違い

通常の勾配降下法は、パラメータ空間での最急降下方向を使用します。しかし、方策最適化では方策空間での距離が重要です。

例えば、ニューラルネットワークの最終層のパラメータが少し変化しただけで、出力される確率分布は大きく変わる可能性があります。

フィッシャー情報行列

自然勾配法では、フィッシャー情報行列F(\theta)を用いて確率分布間の局所的な幾何構造を考慮します。

F(\theta) = \mathbb{E}_{s,a \sim \pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a|s) \nabla_\theta \log \pi_\theta(a|s)^T \right]

フィッシャー情報行列は、KLダイバージェンスの2次近似と密接に関係しています。パラメータの微小変化\delta\thetaに対して:

D_{KL}(\pi_\theta || \pi_{\theta + \delta\theta}) \approx \frac{1}{2} \delta\theta^T F(\theta) \delta\theta

つまり、フィッシャー情報行列は確率分布の局所的な「曲率」を表し、パラメータ空間での距離を確率分布の変化に対応付けます。

自然勾配

自然勾配は、通常の勾配をフィッシャー情報行列で変換したものです。

\tilde{\nabla}_\theta L(\theta) = F(\theta)^{-1} \nabla_\theta L(\theta)

これにより、方策空間での最急降下方向が得られます。

TRPOの実装上の課題

TRPOは理論的には優れていますが、実装が極めて困難です。

最も大きな問題は、KL制約を満たしながら目的関数を最適化する制約付き最適化問題を解く必要があることです。これは単純な勾配法では解けず、専用の最適化アルゴリズムが必要になります。

さらに、自然勾配法を実装するためにはフィッシャー情報行列F(\theta)の逆行列を計算する必要があります。パラメータ数がn個の場合、フィッシャー情報行列はn \times nの行列となり、その逆行列計算にはO(n^3)の計算量が必要です。現代のニューラルネットワーク、特にLLMは数十億のパラメータを持つため、この計算は現実的ではありません。

また、適切なKL制約の閾値\deltaの設定も困難です。この値は問題に強く依存し、小さすぎると学習が遅くなり、大きすぎると性能が不安定になります。さらに、フィッシャー情報行列が特異に近い場合には数値計算が不安定になるという問題もあります。

これらの理由により、TRPOは理論的な基盤としては重要ですが、実用的なシステムではほとんど使用されていません

PPOの核心的アイデア

クリッピングによる近似

PPOの画期的な発見は、TRPOの複雑な制約付き最適化(自然勾配法やフィッシャー情報行列の計算を含む)を単純なクリッピング操作で近似できるということです。

これにより、TRPOの理論的利点を保持しながら、実装を劇的に簡素化することに成功しました。

PPOの目的関数

L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min(r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t) \right]

ここで

  • r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}:重要度比
  • \epsilon:クリッピングパラメータ(通常0.1〜0.2)
  • A_t:Advantage推定値

クリッピングの直感的理解

クリッピング機構は、Advantageの符号に応じて異なる動作をします。

正のAdvantage(A_t > 0)の場合

  • この行動は良いので、確率を増やしたい
  • しかし、r_t(\theta) > 1 + \epsilonになったら、それ以上の増加を止める
  • 過度な楽観的更新を防ぐ

負のAdvantage(A_t < 0)の場合

  • この行動は悪いので、確率を減らしたい
  • しかし、r_t(\theta) < 1 - \epsilonになったら、それ以上の減少を止める
  • 過度な悲観的更新を防ぐ

数学的な効果

勾配の挙動

PPOの目的関数の勾配を考えてみましょう。

\frac{\partial L^{CLIP}}{\partial \theta} = \begin{cases} \frac{\partial r_t(\theta)}{\partial \theta} A_t & \text{if } 1-\epsilon < r_t(\theta) < 1+\epsilon \\ 0 & \text{otherwise} \end{cases}

重要度比がクリッピング範囲内にある場合は通常の方策勾配として動作します。範囲外では勾配がゼロになります。

TRPOとの関係

PPOのクリッピングは、TRPOのKL制約を近似的に実現しています。

TRPOの制約D_{KL}(\pi_{\theta_{old}} \| \pi_\theta) \leq \delta

PPOの効果\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} \in [1-\epsilon, 1+\epsilon]

KLダイバージェンスと確率比には関係があります。確率比を制限することで間接的にKLダイバージェンスも制限されます。

PPOの実装詳細

Actor-Critic構造

PPOは通常、Actor-Critic構造で実装されます(第9回で学んだ内容の応用)。

Actor(方策ネットワーク) は状態s_tを入力として行動の確率分布\pi_\theta(a|s_t)を出力します。Critic(価値ネットワーク) は状態s_tを入力として状態価値V(s_t)を出力します。

多くの実装では、ActorとCriticは下位層を共有し、計算効率を向上させています。

Advantage推定

第8回で学んだAdvantage関数を、実際にどう推定するかが重要です。

Generalized Advantage Estimation (GAE)

PPOでは通常、GAE(Generalized Advantage Estimation)を使用してAdvantageを推定します。これは第6回で学んだTD(λ)の考え方をAdvantage関数の推定に応用したものです。

TD(λ)では、1ステップのTD誤差から無限ステップのモンテカルロ推定まで、λパラメータで連続的に補間しました。GAEも同様に、短期的なバイアスと長期的な分散のトレードオフをλで制御します。

\hat{A}_t = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}

ここで、\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)はTD誤差です。

この式は、TD(λ)のEligibility Traceと同じ構造を持っています。各時点のTD誤差を、(\gamma \lambda)^lで減衰させながら足し合わせることで、以下の特性を実現します。

  • λ = 0のとき\hat{A}_t = \delta_tとなり、1ステップのTD誤差のみを使用(高バイアス、低分散)
  • λ = 1のとき:実際のリターンから価値関数を引いた値に収束(低バイアス、高分散)
  • 0 < λ < 1のとき:バイアスと分散のバランスを調整

LLMでの簡略化

LLMのRLHFでは、エピソード終了時のみに報酬が与えられます。そのため、以下の簡略化がよく使われます。

\hat{A}_t = R - V(s_t)

ここで、Rは文章生成の完了時に得られる報酬モデルからのスコアです。

学習アルゴリズム

PPOの学習手順は以下の通りです。

  1. 現在の方策\pi_{\theta_{old}}を使って、N個の軌道をサンプリング
  2. 各軌道に対してAdvantage \hat{A}_tを計算
  3. Kエポックにわたって以下を繰り返す:
    • ミニバッチをサンプリング
    • PPO目的関数L^{CLIP}を計算
    • 勾配上昇法でパラメータを更新
  4. \theta_{old} \leftarrow \thetaとして、ステップ1に戻る

実装上の工夫

価値関数のクリッピング

Actorだけでなく、Criticの更新にもクリッピングを適用することがあります。

L^V = \max\left[(V_\theta(s_t) - V_{targ})^2, (\text{clip}(V_\theta(s_t), V_{\theta_{old}}(s_t)-\epsilon, V_{\theta_{old}}(s_t)+\epsilon) - V_{targ})^2\right]

これにより、価値関数の推定も安定化されます。

エントロピー正則化

探索を促進するため、エントロピー項を追加することがよくあります。

L = L^{CLIP} - c_1 L^V + c_2 H[\pi_\theta]

ここで、H[\pi_\theta]は方策のエントロピー、c_1, c_2は重み係数です。

LLMへの応用

RLHFにおけるPPO

第10回で学んだRLHFにおいて、PPOは以下のように使用されます。

  1. SFT済みモデルを初期方策\pi_{\theta_{old}}とする
  2. 報酬モデルから報酬を取得
  3. PPOで方策を最適化
  4. KL正則化により、SFTモデルから大きく逸脱しないよう制御

TRPOのKL制約とRLHFのKL正則化の違い

ここで注意すべきは、TRPOで問題となったKL制約と、RLHFで使用されるKL正則化は性質が異なるという点です。

TRPOのKL制約(ハード制約)

  • D_{KL}(\pi_{\theta_{old}} || \pi_\theta) \leq \deltaという厳密な制約
  • 各更新ステップで、現在の方策と更新後の方策の間のKL距離を制限
  • 制約付き最適化問題として解く必要があり、計算が複雑

RLHFのKL正則化(ソフト制約)

  • 報酬関数に-\beta D_{KL}(\pi_\theta || \pi_{SFT})という項を追加
  • SFT(教師あり学習)済みモデルからの逸脱をペナルティとして扱う
  • 通常の最適化問題として解けるため、実装が容易

つまり、PPOはTRPOの「更新ごとのハードなKL制約」を「クリッピング」に置き換えて計算を簡素化しました。一方、RLHFでは「SFTモデルからの逸脱に対するソフトなペナルティ」として別の形でKLダイバージェンスを活用しています。この組み合わせにより、PPOの計算効率性を維持しながら、LLMが完全に逸脱することを防いでいます。

PPOの理論的性質

収束性

PPOは、適切な条件下で局所最適解への収束が保証されています。

PPOの重要な特性として、クリッピング機構により各更新ステップで性能が大幅に悪化することを防ぐ、単調改善の近似的保証があります。TRPOは厳密な単調改善を保証しますが、PPOはクリッピングによってこれを近似的に実現しています。完全な保証ではありませんが、実践的には十分な安定性を提供します。

サンプル効率の面でも、PPOはTRPOと比較してより少ないサンプルで良好な性能を達成できることが経験的に示されています。これは、PPOがより積極的な更新を可能にしながらも、クリッピングによって破滅的な失敗を防いでいるためです。

限界と課題

サンプル効率

PPOはon-policy手法であるため、古いデータを再利用できません。これは、特に大規模なLLMでは計算コストの問題となります。

ハイパーパラメータ感度

クリッピングパラメータ\epsilon、学習率、エポック数などの設定が性能に大きく影響します。

報酬ハッキング

LLMの文脈では、モデルが報酬モデルの弱点を突いて、実際には低品質な出力で高い報酬を得る現象が観察されています。

まとめ

PPOの重要性

PPOは、TRPOの理論的洞察を実用的なアルゴリズムに落とし込んだ画期的な手法です。特にPPOは複雑な制約付き最適化を単純なクリッピングで近似することで、実装の容易さと性能を両立しました。

今後の展望

PPOの成功を受けて、さらなる改良手法が提案されています。

計算効率の面では、より少ない計算量で同等の性能を達成する手法の開発が進んでいます。特に大規模言語モデルの学習において、PPOの計算コストは依然として高く、これを削減する研究が活発に行われています。

理論的な面では、なぜクリッピングがこれほど効果的なのか、より深い理解が求められています。PPOの成功は経験的には明らかですが、その理論的な正当性については完全には解明されていません。この理解が深まることで、さらに優れたアルゴリズムの開発につながると期待されています。

次回予告:最新の方策最適化の手法

次回の記事では、PPOの課題を克服するために開発された最新の手法について解説します。

特に注目すべきは、DPO(Direct Preference Optimization)です。この手法は報酬モデルを介さない直接的な最適化を実現し、RLHFのプロセスを大幅に簡素化しました。また、GRPO(Group Relative Policy Optimization)は価値関数を使わない効率的な学習を可能にし、計算コストをさらに削減しています。

PPOが築いた基盤の上に、どのような革新が生まれているのか、詳しく見ていきましょう。

Discussion