Zenn
👻

選考チューニングのRLHFを数式ベースで理解する(強化学習編)

2025/03/20に公開

大規模言語モデル入門および大規模言語モデル入門II出てくるRLHFについて、少し数式で詰まった部分があったので、メモがてら数式の意味を記載していこうと思います。
章番号や数式の番号は書籍のものなので、下記書籍を持ってないとなんのこっちゃだと思います。

https://gihyo.jp/book/2023/978-4-297-13633-8

https://gihyo.jp/book/2024/978-4-297-14393-0

対象は、4.5.2項および12.1.1項です。

強化学習の数式

eq(4.4)

P(yx,ϕ)=i=1Nπϕ(wix,w<i) P(y|x, \phi) = \prod_{i=1}^{N} \pi_\phi(w_i|x, w_{<i})

この式は、プロンプトxxとパラメータϕ\phiにおいて、テキストy=w1,w2,...,wNy = w_1, w_2, ..., w_Nが生成される確率を表している。
πϕ(wix,w<i)\pi_\phi(w_i|x, w_{<i})は、プロンプトxxと生成途中のテキストw<i=W1,w2,...,wi1w_{<i} = W_1, w_2, ..., w_{i-1}がLLMに与えられたときに、wiw_iが出力される確率となる(softmaxを適用した、語彙に対する確率分布のうちの、wiw_iの確率)。

総乗記号\prodの慣れてない方のために、書き下しておきます。

πϕ(w1x)×πϕ(w2x,w1)×πϕ(w3x,w1,w2)×...×πϕ(wNx,w1,w2,...,wN1) \pi_\phi(w_1|x) \times \pi_\phi(w_2|x, w_1) \times \pi_\phi(w_3|x, w_1, w_2) \times ... \times \pi_\phi(w_N|x, w_1, w_2, ..., w_{N-1})

eq(4.5)

ϕ^=argmaxϕExDrlEyP(yx,ϕ)[R(x,y)] \hat{\phi} = \text{argmax}_{\phi} \mathbb{E}_{x \sim D_\text{rl}} \mathbb{E}_{y \sim P(y|x,\phi)} [R(x, y)]

報酬R(x,y)R(x, y)を最大化する方策のパラメータϕ^\hat{\phi}を求める式です。

  • ϕ^\hat{\phi}: 最適なパラメータ値
  • argmaxϕ\text{argmax}_{\phi}: ExDrEyP(yx,ϕ)[R(x,y)]\mathbb{E}_{x \sim D_r} \mathbb{E}_{y \sim P(y|x,\phi)} [R(x, y)]を最大にするパラメータϕ\phiを求める、ということ
  • ExDrl\mathbb{E}_{x \sim D_\text{rl}}: データセットDrlD_\text{rl}からサンプリングされたxxに関する期待値
  • EyP(yx,ϕ)\mathbb{E}_{y \sim P(y|x,\phi)}: テキストyyが生成される確率P(yx,ϕ)P(y|x,\phi)に関する期待値
  • R(x,y)R(x, y):eq(4.6)で説明されてる報酬

期待値が2つあるため混乱するかもしれませんが、これは具体的には:

  1. 外側の期待値ExDrl\mathbb{E}_{x \sim D_\text{rl}}は、データ分布からサンプリングされた様々な入力xxに対する平均を計算します
  2. 内側の期待値EyP(yx,ϕ)\mathbb{E}_{y \sim P(y|x,\phi)}は、特定の入力xxとパラメータϕ\phiが与えられたときに、モデルが生成する可能性のある様々な出力yyに対する平均報酬を計算

を表します。
つまり、この式は「データ分布から得られる様々な入力に対して、モデルが生成する出力の期待報酬を最大化するようなパラメータϕ\phiを見つける」という最適化問題を表しています。

eq(4.6)

R(x,y)=rθ(x,y)βlogP(yx,ϕ)P(yx,ϕinst) R(x, y) = r_{\theta}(x, y) - \beta \log \frac{P(y|x, \phi)}{P(y|x, \phi_{\text{inst}})}

上記にも出てきたR(x,y)R(x, y)の式です。
第1項のrθ(x,y)r_{\theta}(x, y)報酬モデルの出力(報酬)です。
第2項の分母P(yx,ϕinst)P(y|x, \phi_{\text{inst}})は、選考チューニングをする前のLLMがテキストyyを出力する確率となっており、パラメータϕinst\phi_{\text{inst}}は学習させずに保存しておきます。このモデルは、参照モデルとよばれます。
第2項の分子P(yx,ϕ)P(y|x, \phi)は、今回の選考チューニングで学習対象のLLMがテキストyyを出力する確率です。このモデルは方策モデルと呼ばれます。
この第2項は正則化項です。

さて、この報酬R(x,y)R(x, y)は、どんな時に高くなり、どんな時に低くなるのでしょうか。

第1項のrθ(x,y)r_{\theta}(x, y)はそのまま存在するので、報酬モデルが出力する報酬rθ(x,y)r_{\theta}(x, y)が高ければ全体の報酬R(x,y)R(x, y)が高くなり、報酬rθ(x,y)r_{\theta}(x, y)が低ければR(x,y)R(x, y)は低くなります(当たり前)。

第2項は少しややこしいですが、β>0\beta > 0であることを考慮するとこんな風に解釈できるのではないでしょうか?こういう時は場合分けするとわかりやすいですね。

まずはP(yx,ϕ)>P(yx,ϕinst)P(y|x, \phi) > P(y|x, \phi_{\text{inst}})の時、つまり方策モデルがyyを出力する確率が、参照モデルがyyを出力する確率より大きい場合です。
log()\log(\cdot)が正の値をとるので、全体報酬R(x,y)R(x, y)は低くなります。
これは、参照モデルが出力しにくいテキストyyを無理やり出力した場合に報酬を低くすることで、指示チューニングで学習した内容を忘れることを防ぎたい、というわけです。

その逆P(yx,ϕ)<P(yx,ϕinst)P(y|x, \phi) < P(y|x, \phi_{\text{inst}})だと全体報酬R(x,y)R(x, y)は高くなり、両方の確率が等しい場合だと正則化項の寄与は0となります。

eq(4.7)、eq(4.8)およびeq(12.3)は、既に説明済み事項を別の書きかたにしただけなので省略します。

eq(4.6)からわかるように、RLHFの学習には3つのモデルを使用します。そのため多量の計算リソースを必要となり、時間もかかります。
そのため、改善された手法としてDPO(Direct Preference Optimization)があります。

eq(4.9)およびeq(4.10)

ϕ^=argmaxϕ(ExDrlEyP(yx,ϕ)[R(x,y)]+γExDpt[logP(xϕ)]) \hat{\phi} = \text{argmax}_{\phi} \left( \mathbb{E}_{x \sim D_\text{rl}} \mathbb{E}_{y \sim P(y|x,\phi)} [R(x, y)] + \gamma \mathbb{E}_{x \sim D_\text{pt}} [ \log P(x|\phi) ] \right)

ここでは、使われてるデータセットの違いに注意しましょう。
DrlD_\text{rl}は、選考チューニングで使うデータセットです。また、DptD_\text{pt}は、事前学習に使用したデータセットと同様です。

第1項はeq(4.5)と同じです。

第2項は、

ExDpt[logP(xϕ)]=1Dptxilogπϕ(uiuiK,,ui1) \mathbb{E}_{x \sim D_{\text{pt}}} [\log P(x|\phi)] = \frac{1}{|D_{\text{pt}}|} \sum_{x}\sum_{i} \log \pi_{\phi}(u_i|u_{i-K}, \ldots, u_{i-1})

となっていますね。
ExDpt[]\mathbb{E}_{x \sim D_{\text{pt}}} [ \cdot ]の部分は、xxについて期待値を取るので右辺の1Dptx\frac{1}{|D_{\text{pt}}|} \sum_{x}に対応しています。
残りのilogπϕ(uiuiK,,ui1)\sum_{i} \log \pi_{\phi}(u_i|u_{i-K}, \ldots, u_{i-1})はeq(3.2: P35)と文字が変わってるだけで同じです。
uiK,,ui1u_{i-K}, \ldots, u_{i-1}となっているのは、モデルの入力トークンの最大長がKKとなっているので、出力トークンuiu_iからさかのぼってKK個分しかLLMに入力できないためです。

Discussion

ログインするとコメントできます