大規模言語モデル入門および大規模言語モデル入門II出てくるRLHFについて、少し数式で詰まった部分があったので、メモがてら数式の意味を記載していこうと思います。
章番号や数式の番号は書籍のものなので、下記書籍を持ってないとなんのこっちゃだと思います。
https://gihyo.jp/book/2023/978-4-297-13633-8
https://gihyo.jp/book/2024/978-4-297-14393-0
対象は、4.5.1項および12.1.1項です。
報酬モデリングの数式
eq(12.1)
下記の式変形です。
exp(r∗(x,y1))+exp(r∗(x,y2))exp(r∗(x,y1))=σ(r∗(x,y1)−r∗(x,y2))
シグモイド関数は、下記となるので、この形を目指していきます。
1+exp(−x)1
(手書きですみません)

eq(4.2)およびeq(12.2)
多少の表記ゆれ(定義の違い)があるだけで同じものなので、eq(12.2)ベースでみていきます。
下記損失関数(12.2)を最小化するように学習させます。
L(θ)=−E(x,y+,y−)∼Dp[log(σ(rθ(x,y+)−rθ(x,y−)))]
ここで、E[⋅]は期待値、Dpはデータセット、x,y+,y−はデータセットDpから抽出されたプロンプトと好ましい応答と好ましくない応答、rθ(x,y+)は好ましい応答に対する報酬モデルの出力、rθ(x,y−)は好ましくない応答に対する報酬モデルの出力、σ(⋅)はシグモイド関数です。
では、一つ一つ、どんな式なのか理解していきましょう。
rθ(x,y+)−rθ(x,y−)
こちらは簡単ですね。好ましい応答に対する報酬と好ましい応答に対する報酬の差です。
好ましい応答尾の報酬を大きくしたいため、rθ(x,y+)−rθ(x,y−)が正になってほしい、というお気持ちはわかると思います。
これにシグモイド関数を適用するとどうなるでしょう。
-
rθ(x,y+)−rθ(x,y−)が0より大きくなると、σ(⋅)は1に近づく(こちらになることを期待している)
-
rθ(x,y+)−rθ(x,y−)が0より小さくなるとなると、σ(⋅)は0に近づく
さらに、これにlogを適用しているので
-
rθ(x,y+)−rθ(x,y−)が0より大きくなると、log(σ(⋅))は0に近づく
-
rθ(x,y+)−rθ(x,y−)が0より小さくなると、log(σ(⋅))は−∞に近づく
となります。
log(σ(rθ(x,y+)−rθ(x,y−)))の期待値のマイナスをかけたものを最小化したい。
つまり、rθ(x,y+)−rθ(x,y−)が大きい正の数になるように学習させたいことがわかると思います。
このようにしてチューニングされた報酬モデルrθを用いて、LLMが出力する文章に対して報酬rθ(x,y)を返していきます。
Discussion