大規模言語モデル入門IIに出てくるDPOについて、メモがてら数式の意味を記載していこうと思います。
章番号や数式の番号は書籍のものなので、下記書籍を持ってないとなんのこっちゃだと思います。
https://gihyo.jp/book/2023/978-4-297-13633-8
https://gihyo.jp/book/2024/978-4-297-14393-0
今回の対象は、12.1.3項です。式変形が多めになります。
DPOの数式
eq(12.5), eq(12.6)の式変形
こちらは式変形のみです。
変形前の数式の意味はこちらに書いてますので、知りたい方はこちらを確認していただければと思います。
argϕmaxEx∼DrlEy∼πϕ(y∣x)[r(x,y)−βlogπϕref(y∣x)πϕ(y∣x)]
ここで、−β1をかけると、[⋅]の中身の符号が逆になります。これでは等号が成り立たなくなってしまうので、最大化問題を最小化問題(つまりargmax→argmin)に変更します。
=argϕminEx∼DrlEy∼πϕ(y∣x)[logπϕref(y∣x)πϕ(y∣x)−β1r(x,y)]
次に、logでくくるために、少し変更していきます。
まずは第2項をlogにしましょう。
=argϕminEx∼DrlEy∼πϕ(y∣x)[logπϕref(y∣x)πϕ(y∣x)−logexp(β1r(x,y))]
そうしたら、loga−logb=logbaを利用して1つのlogでまとめます。
=argϕminEx∼DrlEy∼πϕ(y∣x)logπϕref(y∣x)exp(β1r(x,y))πϕ(y∣x)
eq(12.9) ~ eq(12.12)の式変形
こちらも、先ほどのeq(12.6)の式変形の続きですので、丁寧に変形していきましょう。
argϕminEx∼DrlEy∼πϕ(y∣x)logπϕref(y∣x)exp(β1r(x,y))πϕ(y∣x)
まず、分母分子にZ(x)1をかけます。
=argϕminEx∼DrlEy∼πϕ(y∣x)logZ(x)1πϕref(y∣x)exp(β1r(x,y))Z(x)1πϕ(y∣x)
そしたら先ほどとは逆で、logba=loga−logbを使って分離します。この時、
a=Z(x)1πϕref(y∣x)exp(β1r(x,y))πϕ(y∣x)
です。それを適用すると下記のようになります。
=argϕminEx∼DrlEy∼πϕ(y∣x)logZ(x)1πϕref(y∣x)exp(β1r(x,y))πϕ(y∣x)−logZ(x)
そしたら第1項の分母はeq(12.8)のπr(y∣x)と全く同じ形ですので、置き換えましょう。
=argϕminEx∼DrlEy∼πϕ(y∣x)[logπr(y∣x)πϕ(y∣x)−logZ(x)]
ここでZ(x)はπϕ(y∣x)には関係ない、つまりϕの関数ではないので、最適なϕを求める際には定数とみなすことができます。
よってほんとは少し異なりますが、下記のように変形できます(第2項は無関係なので消せる)
=argϕminEx∼DrlEy∼πϕ(y∣x)[logπr(y∣x)πϕ(y∣x)]−argϕminEx∼DrlEy∼πϕ(y∣x)[logZ(x)]=argϕminEx∼DrlEy∼πϕ(y∣x)[logπr(y∣x)πϕ(y∣x)]
なお、一般的にはargminは上記のようには分解できませんので、注意が必要です。今回は定数とみなせたので分解できました。
eq(12.8) → eq(12.13)の導出
まず、eq(12.8)の両辺にlogを適用してみましょう。
logπr(y∣x)=logZ(x)πϕref(y∣x)exp(β1r(x,y))=logπϕref(y∣x)+logexp(β1r(x,y))−logZ(x)=logπϕref(y∣x)+β1r(x,y)−logZ(x)
これを、r(x,y)を左辺に、そのほかを右辺に送ると、(符号に注意)
r(x,y)=β(logπr(y∣x)−logπϕref(y∣x)+logZ(x))=βlogπϕref(y∣x)πr(y∣x)+βlogZ(x)
eq(12.15)の導出
こちらは、eq(12.1)にeq(12.14)を代入すると、Z(x)の項がキャンセルされるだけなので省略します。
損失関数を理解する
さて、今までは無味乾燥な式変形をひたすらしてきました。
それもこれも、下記のDPO損失関数を得るためでした。
LDPO(ϕ)=−E(x,y+,y−)∼Dp[logσ(βlogπϕref(y+∣x)πϕ(y+∣x)−βlogπϕref(y−∣x)πϕ(y−∣x))]
この式はどんな意味があるのでしょうか。
そもそも、損失関数の大元はこちらになっています。
LDPO(ϕ)=−E(x,y+,y−)∼Dp[logp∗(y1>y2∣x)]
これは、eq(12.1)とeq(12.2)からもわかると思います。
意味としては、プロンプトxが与えられたとき、応答y1が応答y2よりも好まれる確率p∗(y1>y2∣x)を最大にしたい(LDPO(ϕ)はマイナスをかけてるので最小化したい)というものでした。
なので、大元の式にeq(12.15)を代入すればDPO損失関数が得られるというわけです。
つまるところ、本質はeq(12.2)もeq(12.15)も同じになります。
ただし、eq(12.2)
L(θ)=−E(x,y+,y−)∼Dp[log(σ(rθ(x,y+)−rθ(x,y−)))]
では、損失関数に報酬rθ(x,y∗)が含まれてますね。これでは、報酬rθ(x,y∗)を出力できるモデル(報酬モデル)を通さないと方策モデル学習できないことになりますね。
だから強化学習を使っていたわけです。
それが、DPO損失関数だと、損失関数に報酬がなくなり、代わりに求めたい方策が直接書かれてますね。
学習対象のLLM(方策モデル)の出力が含まれているため、この損失関数を用いて直接最適化できることを意味します。
そのことからこの本ではDPOは、RLHFと同様の訓練を勾配法を用いて直接行えるようにした手法と説明されていますね。
今ならこの意味が少しは理解できると思います。
DPOは過学習しやすい?
損失関数LDPOは、どんなときに小さくなるのでしょうか。
それは、 logπϕref(y+∣x)πϕ(y+∣x)が正の大きい値になり、logπϕref(y−∣x)πϕ(y−∣x)が負の(絶対値が)大きい値になるときです。
※ なぜなら、シグモイドの中身が大きくなるとσ(⋅)は1に近づき、そのlogはゼロに近い値になるため
logπϕref(y+∣x)πϕ(y+∣x)が大きい値を取るには、πϕref(y+∣x)πϕ(y+∣x)が1より大きくなる、つまりπϕ(y+∣x)>πϕref(y+∣x)となり、両辺の差が大きければ大きいほどよい、ということになります。
これは、参照モデルπϕrefが出力しにくい「好まれる出力y+」でもなんでもお構いなしに、学習対象モデルπϕがy+を出力する確率を上げていこう、という趣旨となるので、πϕに無理やりy+を出力させるよう学習を進めるためですね。
もっというとπϕrefが出力しにくいy+のほうが嬉しい、という、参照モデルからかけ離れた方向に学習させようとしているので、RLHFのときにおこなってた正則が壊れていますね。
次に第2項のlogπϕref(y−∣x)πϕ(y−∣x)を見てみましょう。
先ほどと逆なのですが、logが負の大きい値となるには、中身πϕref(y−∣x)πϕ(y−∣x)が1より小さい値にならなければなりません。そのためにはπϕ(y−∣x)<πϕref(y−∣x)となり、この差が大きければ大きいほどよいということになります。
これは、参照モデルπϕrefが出力しやすかったデータy−のときに、思いっきりπϕ(y−∣x)の確率を下げちゃおう、というものになります。
人間が好まないデータは絶対に出力しないぞ、という意思を感じますね。このとき、y−を出力しないように過学習が起こるのです。
DPOに関する説明は以上です。
Discussion