大規模言語モデル入門IIに出てくるDPOについて、メモがてら数式の意味を記載していこうと思います。
章番号や数式の番号は書籍のものなので、下記書籍を持ってないとなんのこっちゃだと思います。
https://gihyo.jp/book/2023/978-4-297-13633-8
https://gihyo.jp/book/2024/978-4-297-14393-0
今回の対象は、12.1.3項です。式変形が多めになります。
DPOの数式
eq(12.5), eq(12.6)の式変形
こちらは式変形のみです。
変形前の数式の意味はこちらに書いてますので、知りたい方はこちらを確認していただければと思います。
\arg \max_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ r(x,y) - \beta \log \frac{\pi_\phi(y|x)}{\pi_{\phi_{\text{ref}}}(y|x)} \right]
ここで、-\frac{1}{\beta}をかけると、[\cdot]の中身の符号が逆になります。これでは等号が成り立たなくなってしまうので、最大化問題を最小化問題(つまり\arg \max \to \arg \min)に変更します。
= \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_{\phi_{\text{ref}}}(y|x)} - \frac{1}{\beta}r(x,y) \right]
次に、\logでくくるために、少し変更していきます。
まずは第2項を\logにしましょう。
= \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_{\phi_{\text{ref}}}(y|x)} - \log \exp \left( \frac{1}{\beta}r(x,y) \right) \right]
そうしたら、\log a - \log b = \log \frac{a}{b}を利用して1つの\logでまとめます。
= \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_{\phi_{\text{ref}}}(y|x) \exp\left(\frac{1}{\beta}r(x,y)\right)} \right]
eq(12.9) ~ eq(12.12)の式変形
こちらも、先ほどのeq(12.6)の式変形の続きですので、丁寧に変形していきましょう。
\arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_{\phi_{\text{ref}}}(y|x) \exp\left(\frac{1}{\beta}r(x,y)\right)} \right]
まず、分母分子に\frac{1}{Z(x)}をかけます。
= \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\frac{1}{Z(x)}\pi_\phi(y|x)}{\frac{1}{Z(x)}\pi_{\phi_{\text{ref}}}(y|x) \exp\left(\frac{1}{\beta}r(x,y)\right)} \right]
そしたら先ほどとは逆で、\log \frac{a}{b} = \log a - \log bを使って分離します。この時、
a = \frac{\pi_\phi(y|x)}{\frac{1}{Z(x)}\pi_{\phi_{\text{ref}}}(y|x) \exp\left(\frac{1}{\beta}r(x,y)\right)}
です。それを適用すると下記のようになります。
= \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\frac{1}{Z(x)}\pi_{\phi_{\text{ref}}}(y|x) \exp\left(\frac{1}{\beta}r(x,y)\right)} - \log Z(x) \right]
そしたら第1項の分母はeq(12.8)の\pi_r(y|x)と全く同じ形ですので、置き換えましょう。
= \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_r(y|x)} - \log Z(x) \right]
ここでZ(x)は\pi_\phi(y|x)には関係ない、つまり\phiの関数ではないので、最適な\phiを求める際には定数とみなすことができます。
よってほんとは少し異なりますが、下記のように変形できます(第2項は無関係なので消せる)
\begin{aligned}
&= \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_r(y|x)} \right] - \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log Z(x) \right] \\
&= \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_r(y|x)} \right]
\end{aligned}
なお、一般的には\arg \minは上記のようには分解できませんので、注意が必要です。今回は定数とみなせたので分解できました。
eq(12.8) → eq(12.13)の導出
まず、eq(12.8)の両辺に\logを適用してみましょう。
\begin{aligned}
\log \pi_r (y|x) &= \log \frac{\pi_{\phi_\text{ref}} (y|x) \exp \left( \frac{1}{\beta} r(x, y) \right) }{Z(x)} \\
&= \log \pi_{\phi_\text{ref}} (y|x) + \log \exp \left( \frac{1}{\beta} r(x, y) \right) - \log Z(x) \\
&= \log \pi_{\phi_\text{ref}} (y|x) + \frac{1}{\beta} r(x, y) - \log Z(x)
\end{aligned}
これを、r(x, y)を左辺に、そのほかを右辺に送ると、(符号に注意)
\begin{aligned}
r(x, y) &= \beta \left( \log \pi_r (y|x) - \log \pi_{\phi_\text{ref}} (y|x) + \log Z(x) \right) \\
&= \beta \log \frac{\pi_r(y|x)}{\pi_{\phi_{\text{ref}}}(y|x)} + \beta \log Z(x)
\end{aligned}
eq(12.15)の導出
こちらは、eq(12.1)にeq(12.14)を代入すると、Z(x)の項がキャンセルされるだけなので省略します。
損失関数を理解する
さて、今までは無味乾燥な式変形をひたすらしてきました。
それもこれも、下記のDPO損失関数を得るためでした。
\mathcal{L}_{\text{DPO}}(\phi) = -\mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}_p} \left[ \log \sigma \left( \beta \log \frac{\pi_\phi(y^+|x)}{\pi_{\phi_{\text{ref}}}(y^+|x)} - \beta \log \frac{\pi_\phi(y^-|x)}{\pi_{\phi_{\text{ref}}}(y^-|x)} \right) \right]
この式はどんな意味があるのでしょうか。
そもそも、損失関数の大元はこちらになっています。
\mathcal{L}_{\text{DPO}}(\phi) = -\mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}_p} \left[ \log p^* (y^1 > y^2 | x) \right]
これは、eq(12.1)とeq(12.2)からもわかると思います。
意味としては、プロンプトxが与えられたとき、応答y^1が応答y^2よりも好まれる確率p^* (y^1 > y^2 | x)を最大にしたい(\mathcal{L}_{\text{DPO}}(\phi)はマイナスをかけてるので最小化したい)というものでした。
なので、大元の式にeq(12.15)を代入すればDPO損失関数が得られるというわけです。
つまるところ、本質はeq(12.2)もeq(12.15)も同じになります。
ただし、eq(12.2)
\mathcal{L}(\theta) = -\mathbb{E}_{(x,y^+,y^-) \sim D_{\text{p}}} \left[ \log \left( \sigma \left( r_\theta(x, y^+) - r_\theta(x, y^-) \right) \right) \right]
では、損失関数に報酬r_\theta(x, y^*)が含まれてますね。これでは、報酬r_\theta(x, y^*)を出力できるモデル(報酬モデル)を通さないと方策モデル学習できないことになりますね。
だから強化学習を使っていたわけです。
それが、DPO損失関数だと、損失関数に報酬がなくなり、代わりに求めたい方策が直接書かれてますね。
学習対象のLLM(方策モデル)の出力が含まれているため、この損失関数を用いて直接最適化できることを意味します。
そのことからこの本ではDPOは、RLHFと同様の訓練を勾配法を用いて直接行えるようにした手法と説明されていますね。
今ならこの意味が少しは理解できると思います。
DPOは過学習しやすい?
損失関数\mathcal{L}_{\text{DPO}}は、どんなときに小さくなるのでしょうか。
それは、 \log \frac{\pi_\phi(y^+|x)}{\pi_{\phi_{\text{ref}}}(y^+|x)}が正の大きい値になり、\log \frac{\pi_\phi(y^-|x)}{\pi_{\phi_{\text{ref}}}(y^-|x)}が負の(絶対値が)大きい値になるときです。
※ なぜなら、シグモイドの中身が大きくなると\sigma(\cdot)は1に近づき、その\logはゼロに近い値になるため
\log \frac{\pi_\phi(y^+|x)}{\pi_{\phi_{\text{ref}}}(y^+|x)}が大きい値を取るには、\frac{\pi_\phi(y^+|x)}{\pi_{\phi_{\text{ref}}}(y^+|x)}が1より大きくなる、つまり\pi_\phi(y^+|x)>\pi_{\phi_{\text{ref}}}(y^+|x)となり、両辺の差が大きければ大きいほどよい、ということになります。
これは、参照モデル\pi_{\phi_{\text{ref}}}が出力しにくい「好まれる出力y^+」でもなんでもお構いなしに、学習対象モデル\pi_\phiがy^+を出力する確率を上げていこう、という趣旨となるので、\pi_\phiに無理やりy^+を出力させるよう学習を進めるためですね。
もっというと\pi_{\phi_{\text{ref}}}が出力しにくいy^+のほうが嬉しい、という、参照モデルからかけ離れた方向に学習させようとしているので、RLHFのときにおこなってた正則が壊れていますね。
次に第2項の\log \frac{\pi_\phi(y^-|x)}{\pi_{\phi_{\text{ref}}}(y^-|x)}を見てみましょう。
先ほどと逆なのですが、\logが負の大きい値となるには、中身\frac{\pi_\phi(y^-|x)}{\pi_{\phi_{\text{ref}}}(y^-|x)}が1より小さい値にならなければなりません。そのためには\pi_\phi(y^-|x)<\pi_{\phi_{\text{ref}}}(y^-|x)となり、この差が大きければ大きいほどよいということになります。
これは、参照モデル\pi_{\phi_{\text{ref}}}が出力しやすかったデータy^-のときに、思いっきり\pi_\phi(y^-|x)の確率を下げちゃおう、というものになります。
人間が好まないデータは絶対に出力しないぞ、という意思を感じますね。このとき、y^-を出力しないように過学習が起こるのです。
DPOに関する説明は以上です。
Discussion