🤖

Group Relative Policy Optimization (GRPO): 大規模言語モデルのための効率的強化学習

に公開

概要

Group Relative Policy Optimization (GRPO)は、2024年にDeepSeekがDeepSeekMathの論文で導入した、大規模言語モデルの強化学習における画期的なアルゴリズムです。従来のProximal Policy Optimization (PPO)が必要とする価値ネットワーク(クリティック)を排除し、グループベースのアドバンテージ推定を採用することで、メモリ要件を大幅に削減しながら優れた性能を実現します。DeepSeekMath-RLでは、GSM8Kで82.9%→88.2%、MATH(競技レベル数学)で46.8%→51.7%という性能向上を達成しました(DeepSeekMath論文)。

1. 強化学習の基礎:なぜPPOが生まれたのか

1.1 ポリシー勾配法の原理

大規模言語モデルの強化学習を理解するため、まずポリシー勾配法の基本原理から説明します。ポリシー\pi_\theta(a|s)をパラメータ\thetaでパラメトリック化されたニューラルネットワークとして表現し、期待累積報酬を最大化することが目標です:

J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)]

ここで\tau = (s_0, a_0, s_1, a_1, \ldots)は軌跡、R(\tau) = \sum_{t=0}^T r_tは累積報酬です。

REINFORCE(Williams, 1992)は最も基本的なポリシー勾配アルゴリズムで、次の勾配推定を使用します:

\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau) \nabla_\theta \log \pi_\theta(\tau)]

これをエピソード内の各ステップに分解すると:

\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=0}^T \left(\sum_{t'=t}^T r_{t'}\right) \nabla_\theta \log \pi_\theta(a_t|s_t)\right]

1.2 分散削減の必要性とベースライン

REINFORCEの根本的な問題は高い分散です。報酬の大きさが大きく変動すると、勾配推定の分散も大きくなり、学習が不安定になります。この問題を解決するため、ベースラインb(s_t)を導入します:

\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=0}^T \left(\sum_{t'=t}^T r_{t'} - b(s_t)\right) \nabla_\theta \log \pi_\theta(a_t|s_t)\right]

ベースラインは勾配の期待値を変えませんが(\mathbb{E}[b(s_t) \nabla_\theta \log \pi_\theta(a_t|s_t)] = 0)、分散を大幅に削減します。最適なベースラインは状態価値関数V^\pi(s_t) = \mathbb{E}[R_t|s_t]です。

1.3 アドバンテージ関数の導入

アドバンテージ関数は、特定の行動が平均的な行動よりもどれだけ良いかを測定します:

A^\pi(s_t, a_t) = Q^\pi(s_t, a_t) - V^\pi(s_t)

ここで:

  • Q^\pi(s_t, a_t) = \mathbb{E}[R_t|s_t, a_t]:行動価値関数
  • V^\pi(s_t) = \mathbb{E}[R_t|s_t]:状態価値関数

アドバンテージを使用すると、ポリシー勾配は:

\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=0}^T A^\pi(s_t, a_t) \nabla_\theta \log \pi_\theta(a_t|s_t)\right]

2. Proximal Policy Optimization (PPO)の詳細解説

2.1 重要度サンプリングとオフポリシー学習

PPOの核心は重要度サンプリング(importance sampling)にあります。古いポリシー\pi_{\theta_{\text{old}}}から収集したデータを使って新しいポリシー\pi_\thetaを更新する際、分布の違いを補正する必要があります:

\mathbb{E}_{a \sim \pi_\theta}[f(a)] = \mathbb{E}_{a \sim \pi_{\theta_{\text{old}}}}\left[\frac{\pi_\theta(a)}{\pi_{\theta_{\text{old}}}(a)} f(a)\right]

重要度比(importance ratio)は:

r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}

これにより、オフポリシー学習の目的関数は:

J^{\text{IS}}(\theta) = \mathbb{E}_{(s,a) \sim \pi_{\theta_{\text{old}}}}\left[r_t(\theta) A^{\pi_{\theta_{\text{old}}}}(s_t, a_t)\right]

2.2 Trust Region問題とクリッピング

重要度サンプリングには根本的な問題があります:\pi_\theta\pi_{\theta_{\text{old}}}の分布が大きく異なると、重要度比が極端に大きくなり、学習が不安定になります。Trust Region Policy Optimization (TRPO)はKLダイバージェンス制約でこれを解決しました:

\max_\theta \mathbb{E}_{(s,a) \sim \pi_{\theta_{\text{old}}}}\left[r_t(\theta) A^{\pi_{\theta_{\text{old}}}}(s_t, a_t)\right]
\text{subject to } \mathbb{E}_{s \sim \pi_{\theta_{\text{old}}}}[\text{KL}[\pi_{\theta_{\text{old}}}(\cdot|s), \pi_\theta(\cdot|s)]] \leq \delta

PPOはこの制約付き最適化問題をクリッピングで近似します:

J^{\text{CLIP}}(\theta) = \mathbb{E}_{(s,a) \sim \pi_{\theta_{\text{old}}}}\left[\min\left(r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t\right)\right]

ここで\text{clip}(x, a, b) = \max(a, \min(x, b))です。

2.3 完全なPPO目的関数

実際のPPOは以下の目的関数を最大化します:

J^{\text{PPO}}(\theta) = \mathbb{E}_{(s,a) \sim \pi_{\theta_{\text{old}}}}\left[\min\left(r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t\right)\right] - \beta \mathbb{E}_{s \sim \pi_{\theta_{\text{old}}}}[\text{KL}[\pi_{\theta_{\text{old}}}(\cdot|s), \pi_\theta(\cdot|s)]]

第二項のKLペナルティ項は追加の安定化機構として機能します。

2.4 アドバンテージ関数の重要性

ここまでPPOの目的関数に登場するA_t(アドバンテージ)について触れてきましたが、実際にこの値をどのように計算するかはまだ説明していません。PPOアルゴリズムを完全に理解し実装するためには、アドバンテージ関数の推定方法を理解することが不可欠です。

理想的には、真のアドバンテージ関数A^{\pi}(s_t, a_t) = Q^{\pi}(s_t, a_t) - V^{\pi}(s_t)を使用したいところですが、実際にはこれらの真の値は未知です。そのため、何らかの方法でアドバンテージを推定する必要があります。

次節では、PPOで標準的に使用されるGeneralized Advantage Estimation (GAE)について詳しく説明します。GAEは、バイアスと分散のトレードオフを調整可能な、洗練されたアドバンテージ推定手法です。

3. Generalized Advantage Estimation (GAE)の詳細

PPOを実装する上で、アドバンテージ関数をどのように推定するかは極めて重要な問題です。Generalized Advantage Estimation (GAE)は、この問題に対する洗練された解決策を提供します。

記号の説明

  • r_t:時刻tで得られる即時報酬(immediate reward)。エージェントが行動a_tを取った直後に環境から受け取る報酬
  • \gamma:割引因子(discount factor)。将来の報酬をどれだけ重視するかを制御するパラメータ(通常0.95〜0.99)
  • V(s_t):状態s_tの価値関数。その状態から期待される累積報酬
  • \lambda:GAEパラメータ。バイアスと分散のトレードオフを制御(通常0.9〜0.99)

3.1 なぜアドバンテージ推定が難しいのか

アドバンテージ関数A^{\pi}(s_t, a_t) = Q^{\pi}(s_t, a_t) - V^{\pi}(s_t)を正確に計算するには、Q^{\pi}V^{\pi}の真の値が必要ですが、これらは通常未知です。単純な推定方法として以下があります:

モンテカルロ推定:エピソード全体を実行して累積報酬を計算

\hat{A}_t^{MC} = R_t - V(s_t)

ここでR_t = \sum_{k=0}^{T-t} \gamma^k r_{t+k}は時刻tからの実際の累積報酬です。r_{t+k}は時刻t+kで得られる即時報酬(immediate reward)を表します。

この方法は不偏推定量ですが、軌跡全体に依存するため分散が非常に大きくなります。

1ステップTD推定:次の状態の価値関数を使用

\hat{A}_t^{TD} = r_t + \gamma V(s_{t+1}) - V(s_t)

ここでr_tは時刻tで得られる即時報酬です。この方法は分散が小さいですが、V(s_{t+1})の推定誤差によりバイアスが生じます

3.2 nステップリターンによる統一的理解

GAEを理解するために、まずnステップリターンを定義します。これは上記の2つの極端な方法の中間的なアプローチです:

R_t^{(n)} = \sum_{k=0}^{n-1} \gamma^k r_{t+k} + \gamma^n V(s_{t+n})

これは「nステップ先まで実際の報酬を使い、その後は価値関数で近似する」という意味です。

  • n=1R_t^{(1)} = r_t + \gamma V(s_{t+1})(1ステップTD、低分散・高バイアス)
  • n=2R_t^{(2)} = r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2})
  • n=\inftyR_t^{(\infty)} = \sum_{k=0}^{\infty} \gamma^k r_{t+k}(モンテカルロ、高分散・無バイアス)

対応するアドバンテージ推定は:

\hat{A}_t^{(n)} = R_t^{(n)} - V(s_t)

3.3 GAEの核心的アイデア

GAEの天才的な洞察は、「異なるnの推定量を適切に組み合わせることで、バイアスと分散のトレードオフを調整できる」という点です。具体的には、指数的に減衰する重みで各nステップ推定を平均化します:

\hat{A}_t^{\text{GAE}} = (1-\lambda) \sum_{n=1}^{\infty} \lambda^{n-1} \hat{A}_t^{(n)}

ここで\lambda \in [0,1]は新しいハイパーパラメータです。

3.4 時間差分誤差によるエレガントな表現

上記の式は複雑に見えますが、時間差分誤差(TD error)を使うことで、GAEは驚くほどエレガントな形式で表現できます。

時間差分誤差を定義します:

\delta_t^V = r_t + \gamma V(s_{t+1}) - V(s_t)

ここでr_tは時刻tでの即時報酬です。この式は「1ステップ先の予測誤差」を表します。実は、nステップアドバンテージは時間差分誤差の和として表現できることが知られています:

\hat{A}_t^{(1)} = \delta_t^V
\hat{A}_t^{(2)} = \delta_t^V + \gamma \delta_{t+1}^V
\hat{A}_t^{(n)} = \sum_{k=0}^{n-1} \gamma^k \delta_{t+k}^V

この性質を使うと、GAEは次のように書き直せます:

\hat{A}_t^{\text{GAE}(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l}^V

これは以下のように展開できます:

\hat{A}_t^{\text{GAE}(\gamma,\lambda)} = \delta_t^V + \gamma\lambda \delta_{t+1}^V + (\gamma\lambda)^2 \delta_{t+2}^V + \cdots

3.5 GAEパラメータの直感的理解

GAEの2つのパラメータ\gamma\lambdaの役割を理解しましょう:

割引因子\gamma

  • 将来の報酬をどれだけ重視するかを制御
  • 通常0.99など1に近い値を使用
  • 小さくすると短期的な報酬を重視

GAEパラメータ\lambda

  • バイアスと分散のトレードオフを制御
  • \lambda = 0\hat{A}_t = \delta_t^V(1ステップ推定のみ、高バイアス・低分散)
  • \lambda = 1\hat{A}_t = \sum_{l=0}^{\infty} \gamma^l \delta_{t+l}^V = R_t - V(s_t)(モンテカルロ推定、低バイアス・高分散)
  • 実用的には\lambda \in [0.9, 0.99]がよく使用され、バイアスと分散のバランスを取ります

4. PPOの課題とGRPOへの動機

4.1 価値ネットワークの訓練困難性

PPOにおける価値ネットワークV_\psi(s)の訓練は以下の損失関数を最小化します:

L^{VF}(\psi) = \mathbb{E}_{(s,r) \sim \mathcal{D}}\left[\left(V_\psi(s) - R^{\text{target}}\right)^2\right]

ここでR^{\text{target}}はGAEで計算されたターゲット値です。しかし、この訓練には以下の問題があります:

  1. 移動ターゲット問題:ポリシーが変化するとターゲット値も変化し、価値関数の学習が不安定
  2. 過学習:価値ネットワークが特定の状態分布に過度に適合する可能性
  3. 計算コスト:ポリシーネットワークと同等のサイズのネットワークを追加で維持

4.2 LLMにおける特有の課題

大規模言語モデルのコンテキストでは、追加の課題があります:

  1. 系列レベル報酬:通常、報酬は系列の最後でのみ与えられる
  2. 高次元状態空間:各トークンが状態を表し、語彙サイズが大きい
  3. 長系列:数百から数千トークンの長い系列を扱う必要がある

これらの特性により、従来の価値関数推定手法の効果が限定的になります。

5. GRPOの革新的アプローチ

5.1 グループベースアドバンテージの動機

GRPOの核心的洞察は、価値関数を学習する代わりに、同じプロンプトに対する複数の応答の統計を使用してベースラインを計算することです。与えられた質問qに対して、G個の応答をサンプリングします:

\{o_1, o_2, \ldots, o_G\} \sim \pi_{\theta_{\text{old}}}(\cdot|q)

各応答o_iは報酬r_iを受け取ります。グループベースアドバンテージは:

\hat{A}_i = \frac{r_i - \frac{1}{G}\sum_{j=1}^G r_j}{\sqrt{\frac{1}{G-1}\sum_{j=1}^G \left(r_j - \frac{1}{G}\sum_{k=1}^G r_k\right)^2}}

これは統計学の標準化と同じ形式です:\hat{A}_i = \frac{r_i - \mu}{\sigma}

5.2 理論的正当性

このアプローチの理論的基盤は以下にあります:

  1. 大数の法則Gが大きいとき、\frac{1}{G}\sum_{j=1}^G r_j \to \mathbb{E}_{\pi_{\theta_{\text{old}}}}[r(o)]
  2. 分散削減:標準化により各グループ内でゼロ平均、単位分散を保証
  3. 比較的性質:報酬モデルの訓練方法(比較判断)と自然に整合

5.3 GRPO目的関数の導出

GRPOは以下の目的関数を最適化します:

J_{\text{GRPO}}(\theta) = \mathbb{E}_{q \sim \mathcal{D}, \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(\cdot|q)} \left[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \min\left( r_{i,t}(\theta) \hat{A}_{i,t}, \text{clip}(r_{i,t}(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_{i,t} \right) \right] - \beta \mathcal{D}_{\text{KL}}[\pi_\theta \| \pi_{\text{ref}}]

ここで:

  • r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}|q, o_{i,<t})}:トークンレベルの重要度比
  • \hat{A}_{i,t} = \hat{A}_i:系列レベルのアドバンテージを全トークンに適用
  • \beta:KL正則化係数(DeepSeekMathでは0.04を使用)

5.4 KLダイバージェンス正則化の実装

GRPOでは、KLダイバージェンスの推定に以下の不偏推定量を使用します:

\hat{\mathcal{D}}_{\text{KL}}[\pi_\theta \| \pi_{\text{ref}}] = \frac{\pi_{\text{ref}}(o_{i,t}|q, o_{i,<t})}{\pi_\theta(o_{i,t}|q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t}|q, o_{i,<t})}{\pi_\theta(o_{i,t}|q, o_{i,<t})} - 1

この推定量は常に非負であることが保証されています。

6. 理論的分析と将来の発展

6.1 GRPOの理論的保証

GRPOの収束性に関する理論的分析:

  1. バイアス分析:グループサイズG \to \inftyのとき、グループ平均は真の期待値に収束
  2. 分散分析:標準化により分散を制御、安定した学習を保証
  3. 収束レート:適切な学習率スケジュールの下で、局所最適解への収束を保証

6.2 改良版アルゴリズム

Dr. GRPO(長さバイアス軽減)[1]

\tilde{A}_i = r_i - \frac{1}{G}\sum_{j=1}^G r_j

標準偏差による正規化を除去し、長い間違った答えと短い正しい答えの不公平な比較を防ぎます。

GVPO(Group Variance Policy Optimization)[2]
分散も考慮した目的関数:

J_{\text{GVPO}}(\theta) = \mathbb{E}[r] - \alpha \text{Var}[r]

6.3 統一的理解のためのパラダイム

DeepSeekMathは異なる手法を統一的に理解するフレームワークを提案:

\nabla_\theta J_A(\theta) = \mathbb{E}_{(q,o) \sim D}\left[\frac{1}{|o|} \sum_{t=1}^{|o|} GC_A(q, o, t, \pi_{rf}) \nabla_\theta \log \pi_\theta(o_t|q, o_{<t})\right]

ここでGC_Aは勾配係数(Gradient Coefficient)で、各手法は異なるGC_Aを使用:

手法 データソース 勾配係数
SFT 人間選択データ 1
RFT SFTモデルサンプル \mathbb{I}(\text{correct})
DPO ペア比較データ \sigma(\beta \log \frac{\pi_\theta}{\pi_{ref}})
PPO ポリシーサンプル A_t (GAE)
GRPO グループサンプル \hat{A}_i (グループ相対)

結論

Group Relative Policy Optimization (GRPO)は、大規模言語モデルの強化学習において重要な進歩を表しています。従来のPPOが抱えていた価値ネットワークの複雑性とコストの問題を、グループベースの統計的アプローチで解決しました。DeepSeekMathでの成功実証により、数学的推論をはじめとする様々なタスクでの有効性が確認されています。

GRPOの理論的基盤は統計学と強化学習の自然な融合であり、実装の簡素性と計算効率性を両立しています。今後、より大規模なモデルや複雑なタスクへの適用が期待され、AI研究の民主化に重要な役割を果たすと考えられます。計算リソースの制約下でも高性能な推論モデルを訓練可能にするGRPOは、次世代AI開発の基盤技術として位置づけられるでしょう。

脚注
  1. Dr. GRPOはDeepSeekMathチームが開発した改良版で、詳細はDeepSeekMath論文のAppendix B.3を参照。 ↩︎

  2. GVPOはリスク考慮型強化学習の観点から報酬の分散を最小化する手法。詳細はDeepSeekMath論文のSection 5.3を参照。 ↩︎

Discussion