🐣

生成 AI を私の色で染めたい ~RLHF から DPO へ~

2024/11/15に公開

こんにちは。お行儀にはちょっぴり自信がないひよこです。

AI モデルを人間の好みに合わせて調整する手法は、近年の生成 AI の進歩に欠かせない技術です。まずは、GPT の RLHF (Reinforcement Learning with Human Feedback) によって、大規模言語モデル (LLM) にユーザの選好を反映させることが可能になりました。しかし、強化学習を用いる RLHF は実装が非常に大変で、もっと簡単にモデル調整ができないかという課題がありました。その解決策として登場したのが DPO (Direct Preference Optimization) です。
現在では、DPO の改良手法として IPO (Iterative Preference Optimization)KTO (KL-regularized Preference Optimization) が提案され、ユーザ好みに基づくモデル調整手法として注目を集めています。

今回は、RLHF の概要を簡単に説明し、それに続く 3 つの手法(DPO、IPO、KTO)について解説します。


なぜユーザの選好を教える必要があるの?

そもそも、なぜこのような学習が必要なのでしょうか?例えば、LLM(大規模言語モデル)であればすでに膨大なデータで学習済みです。そのデータには、人々の集合知として長い歴史の中で淘汰されてきた選好や価値観が反映されているはずです。それならば、すでに十分な選好情報がモデルに含まれているのではないでしょうか?

これは知識の吸収という側面では正しい考え方です。しかし、ここにもう一つのややこしい問題が潜んでいました。それは、私たちの社会には「本音と建前」を使い分けないといけないという強烈な不文律があったことです。つまり LLM に膨大な知識を持たせるだけでは不十分であり、TPO(時間・場所・状況)に応じて適切な応答をするための 「お行儀」教育 が不可欠だったということです。生成 AI の黎明期にはこれを怠ったばかりに消えていったサービスがたくさんありました。

この「お行儀」教育を初めて具体的に実現した手法が RLHF (Reinforcement Learning with Human Feedback) でした。RLHF は、人間のフィードバックを通じてモデルに社会的な望ましい応答の形を学習させることで、AI をより「安全で信頼できる存在」にするための重要なステップを踏み出しました。


RLHF (Reinforcement Learning with Human Feedback)

基本概要

RLHF は、強化学習を活用して人間のフィードバックをモデルに反映する手法です。具体的には、選好データを基に報酬モデル R(x) を構築し、その報酬を基に生成モデルを「どのように応答を生成するか」という強化学習における政策(ポリシー)を最適化します。

この手法は、以下の 3 段階で構成されます:

  1. 報酬モデルの訓練

    • 人間の選好データを使用し、入力 x に対する「どれだけ好ましいか」のスコアを予測する報酬モデル R(x) を訓練します。
  2. ポリシーの更新

    • 報酬モデル R(x) を基に、生成モデルのポリシー \pi_\theta を調整します。ここでいうポリシーとは、モデルが「どんな条件で、どのような応答を出すか」を決定するルールです。この段階では、強化学習の一種である PPO (Proximal Policy Optimization) が一般的に用いられます。
  3. ポリシーの最適化

    • 報酬に基づいて生成データの分布をさらに洗練し、モデルがより好まれる出力を生成できるように調整します。

🐣飼い主(報酬モデル)がおやつを使って「こうすると褒めてもらえる」というヒントを与え、ペット(ポリシー)がそれを繰り返して覚えるということだね🐶✨

数式と考え方

RLHF の基本的な流れは、報酬モデル R(x) を学習し、それを基にポリシーを更新することです。PPO を用いたポリシー更新の損失関数は以下のように定義されます:

\mathcal{L}_{\text{PPO}}(\theta) = \mathbb{E}_{x \sim \pi_{\text{old}}} \left[\min\left(r_\theta(x) \cdot A(x), \text{clip}(r_\theta(x), 1-\epsilon, 1+\epsilon) \cdot A(x)\right)\right]

ここで:

  • r_\theta(x) = \frac{\pi_\theta(x)}{\pi_{\text{old}}(x)} は新旧ポリシーの確率比を表します。
  • A(x) はアドバンテージ関数で、特定の行動 x の有利さを示します。
  • \epsilon はクリッピング範囲を制御するハイパーパラメータです。

RLHF の問題点

RLHF を活用する上で、以下の 2 点が大きな課題として挙げられます:

トレードオフの問題

RLHF では、モデルが人間の選好に沿った出力を生成することと、元々の LLM の能力を損なわないこととの間にトレードオフがあります。極端に「お行儀よく」が良くても創造性や賢さが犠牲になってしまえば本末転倒です。

このトレードオフを調整するために、PPO のクリッピング機構が重要な役割を果たします。具体的には、更新後のポリシーと更新前のポリシーの比率を一定範囲内(1-\epsilon1+\epsilon)に制限することで、ポリシーの急激な変更を抑えています。

学習の不安定性と計算コスト

  • RLHF の学習には多くのハイパーパラメータ(クリッピング範囲 \epsilon、Value 関数の学習率、PPO の反復回数など)の調整が必要です。これらが適切でない場合、学習が不安定になったり、収束しないことがあります。

  • ポリシーの更新ステップにおいて報酬モデルの訓練(深層学習)と、ポリシーの更新(強化学習)の 2 段階を踏む必要があり、特に PPO の部分は計算コストが高く、収束までに非常に多くのリソースを消費します。

🐣 LLM の学習だけでも大変なのに、それとは別に強化学習も必要だったなんて絶句…


DPO (Direct Preference Optimization)

RLHF では、PPO によるポリシーの学習がとにかく大変でした。そこで、RLHF における強化学習のアプローチを大胆に省略し、単純な最適化に置き換えた手法が DPO (Direct Preference Optimization) です。

基本概要

DPO は RLHF と同様に「人間の好み」を基にモデルを調整する手法です。例えば、「文章 A と文章 B のどちらが好まれるか?」という選好関係データを基に、モデルがその好みに従った予測を行えるように改善します。

数式と考え方

DPO の目標は、RLHF のようなポリシーではなく、以下のような好みの確率を予測するモデルを作ることです。

P(A \succ B) = \frac{\exp(f(A))}{\exp(f(A)) + \exp(f(B))}
  • f(A):モデルが予測する A のスコア。
  • f(B):モデルが予測する B のスコア。
  • P(A \succ B):A が B より好まれる確率。

これを基に、モデルを訓練するための損失関数は次のように定義されます。

\mathcal{L}(\theta) = -\mathbb{E}_{(x_i, x_j) \sim \mathcal{D}} \left[\log \frac{\pi_\theta(x_i)}{\pi_\theta(x_i) + \pi_\theta(x_j)}\right]

この損失関数は、「モデルの予測が人間の好みデータにどれだけ合っているか」を測定し、間違いが少なくなるようモデルを調整します。RLHF に比べてこの式の最大の利点は、勾配法による最適化が可能なため、強化学習による報酬モデルが不要になる点です。

🐣 f(x)\pi_\theta(x) は同じスコア関数だよ!\theta はモデルの調整パラメータだよ

スコアと確率の計算例(DPO)

好みデータとして「A が B より好まれる」という情報を持つ場合、以下の値を仮定します:

  • f(A) = 2
  • f(B) = 1

このとき、確率 P(A \succ B) を計算します。

  1. \exp(f(A)) = \exp(2) = e^2 \approx 7.39
  2. \exp(f(B)) = \exp(1) = e^1 \approx 2.72

確率は次のようになります:

P(A \succ B) = \frac{\exp(2)}{\exp(2) + \exp(1)} = \frac{7.39}{7.39 + 2.72} \approx 0.73

つまり、「A が B より好まれる確率は 73%」と予測されます。

損失関数の計算例(DPO)

この予測が人間の評価データと一致しているかを損失関数で確認します。

損失関数は以下の式で表されます。

\mathcal{L}(\theta) = -\log(P(A \succ B))

ここで P(A \succ B) = 0.73 を代入すると:

\mathcal{L}(\theta) = -\log(0.73) \approx 0.313

この損失が小さいほど、モデルが好みを正しく予測していることを示します。


DPO の課題:モデル崩壊のリスク

DPO はシンプルでありながら、提案論文では RLHF との等価性が証明されています。しかしながら、この等価性は「無限のデータ」や「完全なモデル」など理想的な条件下に限定されます。実際の運用では、いくつかの課題が残っています。その中でも特に問題となるのが モデル崩壊(mode collapse) です。

モデル崩壊が起こる理由

  • 極端なスコア付け
    損失関数の形状により、モデルが「正しい選択肢」と「誤った選択肢」の間に極端なスコア差をつける傾向があります。その結果、ある選択肢だけが強調されすぎて他を無視する状態になります。

  • データの偏り
    人間の評価データが偏っている場合、モデルがその偏りを過剰に学習し、適切な予測ができなくなることがあります。

DPO では A が B よりも良い、という順序関係だけを学習してしまうため、ちょっとだけ良い、のような微妙な匙加減を表現できなくなってしまっているのです。

モデル崩壊の具体例

例えば、以下の好みデータを考えます:

  • A \succ BA の方が B より好ましい)
  • B \succ CB の方が C より好ましい)

このデータに基づいて DPO を適用すると、モデルが極端なスコアを学習し、次のような状態になることがあります:

  • f(A) = 100, f(B) = 1, f(C) = 0.01

この場合、モデルは A を過剰に評価し、BC をほぼ無視するようになります。この状態では、新しいデータに対する予測が不安定になり、モデル全体の性能が低下します。


IPO (Iterative Preference Optimization)

基本概要

IPO は DPO の考え方を発展させた手法で、「モデルを安定的に改善する」ことを目的としています。その特徴は「正則化」という仕組みを導入している点です。

正則化の役割

IPO では、モデルが極端に変化しないよう、以下の正則化項を損失関数に加えます。二乗誤差なので L2 正則化と呼ばれます。

🐣 絶対値誤差だと L1 正則化ですね

\lambda \|\pi_\theta - \pi_{\text{base}}\|^2
  • \pi_\theta:現在のモデルの出力スコア。
  • \pi_{\text{base}}:以前のモデル(基準となる分布)のスコア。
  • \lambda:正則化の強さを決めるハイパーパラメータ。

ここで、以前のモデルとはお行儀を躾けられる前のモデルのことです。
この正則化項により、モデルの大幅な変化を抑えつつ、安定的に改善することが可能になります。

正則化項の計算例(IPO)

正則化項は、モデルのスコア \pi_\theta が基準スコア \pi_{\text{base}} にどれだけ近いかを測定します。

仮定

  • 現在のスコア \pi_\theta(A) = 0.8, \pi_\theta(B) = 0.2
  • 基準スコア \pi_{\text{base}}(A) = 0.5, \pi_{\text{base}}(B) = 0.5

正則化項は次のように計算されます。

\|\pi_\theta - \pi_{\text{base}}\|^2 = (0.8 - 0.5)^2 + (0.2 - 0.5)^2
  1. (0.8 - 0.5)^2 = 0.09
  2. (0.2 - 0.5)^2 = 0.09

よって、

\|\pi_\theta - \pi_{\text{base}}\|^2 = 0.09 + 0.09 = 0.18

この値を損失関数に追加することで、モデルのスコアが極端になりすぎるのを防ぎます。


KTO (KL-regularized Preference Optimization)

基本概要

KTO は、IPO をさらに発展させた手法で、「モデルの分布が基準の分布に近づくように」調整するため KL 情報量 に基づく正則化を導入します。

KL 情報量とは?

KL 情報量(Kullback–Leibler divergence, KL)は、2 つの分布がどれだけ異なるかを測る指標で、次式で表されます。

D_{\text{KL}}(P \| Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}

この値は、PQ の分布が完全に一致している場合は 0 になり、分布が離れるほど大きくなります。

KL の計算例(KTO)

モデルの分布 \pi_\theta と基準の分布 \pi_{\text{base}} の違いを KL 情報量で評価します。

仮定

  • \pi_\theta(A) = 0.7, \pi_\theta(B) = 0.3
  • \pi_{\text{base}}(A) = 0.5, \pi_{\text{base}}(B) = 0.5

このとき、KL 情報量は次の式で計算されます。

D_{\text{KL}}(\pi_\theta \| \pi_{\text{base}}) = \pi_\theta(A) \log \frac{\pi_\theta(A)}{\pi_{\text{base}}(A)} + \pi_\theta(B) \log \frac{\pi_\theta(B)}{\pi_{\text{base}}(B)}
  1. \pi_\theta(A) \log \frac{\pi_\theta(A)}{\pi_{\text{base}}(A)} = 0.7 \log \frac{0.7}{0.5} \approx 0.7 \cdot 0.356 = 0.249
  2. \pi_\theta(B) \log \frac{\pi_\theta(B)}{\pi_{\text{base}}(B)} = 0.3 \log \frac{0.3}{0.5} \approx 0.3 \cdot (-0.511) = -0.153

よって、

D_{\text{KL}}(\pi_\theta \| \pi_{\text{base}}) \approx 0.249 - 0.153 = 0.096

この値を損失関数に加えることで、モデルの分布が急激に変化するのを防ぎます。


実は…

RLHF では、PPO のクリッピング機構に加えて、報酬関数に KL 項によるペナルティを含めることで、ポリシーの急激な変化を抑制していました。DPO では、RLHF からこれらの仕組みをバッサリ削ぎ落としてしまい、その結果としてモデル崩壊(mode collapse)を招いたわけです。つまり

  • IPO は、削ぎ落とされた クリッピング の要素を再導入
  • KTO は、削ぎ落とされた KL 正則化 の要素を再導入

した手法だ、とみなすことも可能です。そういった意味では RLHF はよく考えられていたというわけですね。

🐣 DPO はダイエットし過ぎてフラフラだったというわけか


まとめ

RLHF, DPO、IPO、KTO、はいずれも「人間の好みのデータ」を基にモデルを調整する手法ですが、それぞれ異なる特徴を持っていることがわかりました。RLHF から引き算しすぎた DPO から引かれすぎた要素をもう一度足すことで IPO や KTO が提案された流れは面白いですよね。

一般に広く公開される LLM では倫理感や社会的秩序のような堅い理念に基づいて教育されてしまうイメージですが、もちろんローカル LLM であれば趣味全開の嗜好を入れ込むことも可能です。これらの手法を上手に使って自分色に染め上げた LLM を作ってみるのも楽しそうですよね。

🐣 学習に使った選好データセットの流出にだけは気を付けましょうw

Discussion