大規模言語モデル入門および大規模言語モデル入門II出てくるRLHFについて、少し数式で詰まった部分があったので、メモがてら数式の意味を記載していこうと思います。
章番号や数式の番号は書籍のものなので、下記書籍を持ってないとなんのこっちゃだと思います。
https://gihyo.jp/book/2023/978-4-297-13633-8
https://gihyo.jp/book/2024/978-4-297-14393-0
対象は、4.5.2項および12.1.1項です。
強化学習の数式
eq(4.4)
P(y∣x,ϕ)=i=1∏Nπϕ(wi∣x,w<i)
この式は、プロンプトxとパラメータϕにおいて、テキストy=w1,w2,...,wNが生成される確率を表している。
πϕ(wi∣x,w<i)は、プロンプトxと生成途中のテキストw<i=W1,w2,...,wi−1がLLMに与えられたときに、wiが出力される確率となる(softmaxを適用した、語彙に対する確率分布のうちの、wiの確率)。
総乗記号∏の慣れてない方のために、書き下しておきます。
πϕ(w1∣x)×πϕ(w2∣x,w1)×πϕ(w3∣x,w1,w2)×...×πϕ(wN∣x,w1,w2,...,wN−1)
eq(4.5)
ϕ^=argmaxϕEx∼DrlEy∼P(y∣x,ϕ)[R(x,y)]
報酬R(x,y)を最大化する方策のパラメータϕ^を求める式です。
-
ϕ^: 最適なパラメータ値
-
argmaxϕ: Ex∼DrEy∼P(y∣x,ϕ)[R(x,y)]を最大にするパラメータϕを求める、ということ
-
Ex∼Drl: データセットDrlからサンプリングされたxに関する期待値
-
Ey∼P(y∣x,ϕ): テキストyが生成される確率P(y∣x,ϕ)に関する期待値
-
R(x,y):eq(4.6)で説明されてる報酬
期待値が2つあるため混乱するかもしれませんが、これは具体的には:
- 外側の期待値Ex∼Drlは、データ分布からサンプリングされた様々な入力xに対する平均を計算します
- 内側の期待値Ey∼P(y∣x,ϕ)は、特定の入力xとパラメータϕが与えられたときに、モデルが生成する可能性のある様々な出力yに対する平均報酬を計算
を表します。
つまり、この式は「データ分布から得られる様々な入力に対して、モデルが生成する出力の期待報酬を最大化するようなパラメータϕを見つける」という最適化問題を表しています。
eq(4.6)
R(x,y)=rθ(x,y)−βlogP(y∣x,ϕinst)P(y∣x,ϕ)
上記にも出てきたR(x,y)の式です。
第1項のrθ(x,y)は報酬モデルの出力(報酬)です。
第2項の分母P(y∣x,ϕinst)は、選考チューニングをする前のLLMがテキストyを出力する確率となっており、パラメータϕinstは学習させずに保存しておきます。このモデルは、参照モデルとよばれます。
第2項の分子P(y∣x,ϕ)は、今回の選考チューニングで学習対象のLLMがテキストyを出力する確率です。このモデルは方策モデルと呼ばれます。
この第2項は正則化項です。
さて、この報酬R(x,y)は、どんな時に高くなり、どんな時に低くなるのでしょうか。
第1項のrθ(x,y)はそのまま存在するので、報酬モデルが出力する報酬rθ(x,y)が高ければ全体の報酬R(x,y)が高くなり、報酬rθ(x,y)が低ければR(x,y)は低くなります(当たり前)。
第2項は少しややこしいですが、β>0であることを考慮するとこんな風に解釈できるのではないでしょうか?こういう時は場合分けするとわかりやすいですね。
まずはP(y∣x,ϕ)>P(y∣x,ϕinst)の時、つまり方策モデルがyを出力する確率が、参照モデルがyを出力する確率より大きい場合です。
log(⋅)が正の値をとるので、全体報酬R(x,y)は低くなります。
これは、参照モデルが出力しにくいテキストyを無理やり出力した場合に報酬を低くすることで、指示チューニングで学習した内容を忘れることを防ぎたい、というわけです。
その逆P(y∣x,ϕ)<P(y∣x,ϕinst)だと全体報酬R(x,y)は高くなり、両方の確率が等しい場合だと正則化項の寄与は0となります。
eq(4.7)、eq(4.8)およびeq(12.3)は、既に説明済み事項を別の書きかたにしただけなので省略します。
eq(4.6)からわかるように、RLHFの学習には3つのモデルを使用します。そのため多量の計算リソースを必要となり、時間もかかります。
そのため、改善された手法としてDPO(Direct Preference Optimization)があります。
eq(4.9)およびeq(4.10)
ϕ^=argmaxϕ(Ex∼DrlEy∼P(y∣x,ϕ)[R(x,y)]+γEx∼Dpt[logP(x∣ϕ)])
ここでは、使われてるデータセットの違いに注意しましょう。
Drlは、選考チューニングで使うデータセットです。また、Dptは、事前学習に使用したデータセットと同様です。
第1項はeq(4.5)と同じです。
第2項は、
Ex∼Dpt[logP(x∣ϕ)]=∣Dpt∣1x∑i∑logπϕ(ui∣ui−K,…,ui−1)
となっていますね。
Ex∼Dpt[⋅]の部分は、xについて期待値を取るので右辺の∣Dpt∣1∑xに対応しています。
残りの∑ilogπϕ(ui∣ui−K,…,ui−1)はeq(3.2: P35)と文字が変わってるだけで同じです。
ui−K,…,ui−1となっているのは、モデルの入力トークンの最大長がKとなっているので、出力トークンuiからさかのぼってK個分しかLLMに入力できないためです。
Discussion