Zenn
🦁

選考チューニングのDPOを数式ベースで理解する

2025/03/21に公開

大規模言語モデル入門IIに出てくるDPOについて、メモがてら数式の意味を記載していこうと思います。
章番号や数式の番号は書籍のものなので、下記書籍を持ってないとなんのこっちゃだと思います。

https://gihyo.jp/book/2023/978-4-297-13633-8

https://gihyo.jp/book/2024/978-4-297-14393-0

今回の対象は、12.1.3項です。式変形が多めになります。

DPOの数式

eq(12.5), eq(12.6)の式変形

こちらは式変形のみです。
変形前の数式の意味はこちらに書いてますので、知りたい方はこちらを確認していただければと思います。

argmaxϕExDrlEyπϕ(yx)[r(x,y)βlogπϕ(yx)πϕref(yx)] \arg \max_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ r(x,y) - \beta \log \frac{\pi_\phi(y|x)}{\pi_{\phi_{\text{ref}}}(y|x)} \right]

ここで、1β-\frac{1}{\beta}をかけると、[][\cdot]の中身の符号が逆になります。これでは等号が成り立たなくなってしまうので、最大化問題を最小化問題(つまりargmaxargmin\arg \max \to \arg \min)に変更します。

=argminϕExDrlEyπϕ(yx)[logπϕ(yx)πϕref(yx)1βr(x,y)] = \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_{\phi_{\text{ref}}}(y|x)} - \frac{1}{\beta}r(x,y) \right]

次に、log\logでくくるために、少し変更していきます。
まずは第2項をlog\logにしましょう。

=argminϕExDrlEyπϕ(yx)[logπϕ(yx)πϕref(yx)logexp(1βr(x,y))] = \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_{\phi_{\text{ref}}}(y|x)} - \log \exp \left( \frac{1}{\beta}r(x,y) \right) \right]

そうしたら、logalogb=logab\log a - \log b = \log \frac{a}{b}を利用して1つのlog\logでまとめます。

=argminϕExDrlEyπϕ(yx)[logπϕ(yx)πϕref(yx)exp(1βr(x,y))] = \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_{\phi_{\text{ref}}}(y|x) \exp\left(\frac{1}{\beta}r(x,y)\right)} \right]

eq(12.9) ~ eq(12.12)の式変形

こちらも、先ほどのeq(12.6)の式変形の続きですので、丁寧に変形していきましょう。

argminϕExDrlEyπϕ(yx)[logπϕ(yx)πϕref(yx)exp(1βr(x,y))] \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_{\phi_{\text{ref}}}(y|x) \exp\left(\frac{1}{\beta}r(x,y)\right)} \right]

まず、分母分子に1Z(x)\frac{1}{Z(x)}をかけます。

=argminϕExDrlEyπϕ(yx)[log1Z(x)πϕ(yx)1Z(x)πϕref(yx)exp(1βr(x,y))] = \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\frac{1}{Z(x)}\pi_\phi(y|x)}{\frac{1}{Z(x)}\pi_{\phi_{\text{ref}}}(y|x) \exp\left(\frac{1}{\beta}r(x,y)\right)} \right]

そしたら先ほどとは逆で、logab=logalogb\log \frac{a}{b} = \log a - \log bを使って分離します。この時、

a=πϕ(yx)1Z(x)πϕref(yx)exp(1βr(x,y)) a = \frac{\pi_\phi(y|x)}{\frac{1}{Z(x)}\pi_{\phi_{\text{ref}}}(y|x) \exp\left(\frac{1}{\beta}r(x,y)\right)}
b=Z(x) b = Z(x)

です。それを適用すると下記のようになります。

=argminϕExDrlEyπϕ(yx)[logπϕ(yx)1Z(x)πϕref(yx)exp(1βr(x,y))logZ(x)] = \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\frac{1}{Z(x)}\pi_{\phi_{\text{ref}}}(y|x) \exp\left(\frac{1}{\beta}r(x,y)\right)} - \log Z(x) \right]

そしたら第1項の分母はeq(12.8)のπr(yx)\pi_r(y|x)と全く同じ形ですので、置き換えましょう。

=argminϕExDrlEyπϕ(yx)[logπϕ(yx)πr(yx)logZ(x)] = \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_r(y|x)} - \log Z(x) \right]

ここでZ(x)Z(x)πϕ(yx)\pi_\phi(y|x)には関係ない、つまりϕ\phiの関数ではないので、最適なϕ\phiを求める際には定数とみなすことができます。
よってほんとは少し異なりますが、下記のように変形できます(第2項は無関係なので消せる)

=argminϕExDrlEyπϕ(yx)[logπϕ(yx)πr(yx)]argminϕExDrlEyπϕ(yx)[logZ(x)]=argminϕExDrlEyπϕ(yx)[logπϕ(yx)πr(yx)] \begin{aligned} &= \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_r(y|x)} \right] - \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log Z(x) \right] \\ &= \arg \min_\phi \mathbb{E}_{x \sim \mathcal{D}_{\text{rl}}} \mathbb{E}_{y \sim \pi_\phi(y|x)} \left[ \log \frac{\pi_\phi(y|x)}{\pi_r(y|x)} \right] \end{aligned}

なお、一般的にはargmin\arg \minは上記のようには分解できませんので、注意が必要です。今回は定数とみなせたので分解できました。

eq(12.8) → eq(12.13)の導出

まず、eq(12.8)の両辺にlog\logを適用してみましょう。

logπr(yx)=logπϕref(yx)exp(1βr(x,y))Z(x)=logπϕref(yx)+logexp(1βr(x,y))logZ(x)=logπϕref(yx)+1βr(x,y)logZ(x) \begin{aligned} \log \pi_r (y|x) &= \log \frac{\pi_{\phi_\text{ref}} (y|x) \exp \left( \frac{1}{\beta} r(x, y) \right) }{Z(x)} \\ &= \log \pi_{\phi_\text{ref}} (y|x) + \log \exp \left( \frac{1}{\beta} r(x, y) \right) - \log Z(x) \\ &= \log \pi_{\phi_\text{ref}} (y|x) + \frac{1}{\beta} r(x, y) - \log Z(x) \end{aligned}

これを、r(x,y)r(x, y)を左辺に、そのほかを右辺に送ると、(符号に注意)

r(x,y)=β(logπr(yx)logπϕref(yx)+logZ(x))=βlogπr(yx)πϕref(yx)+βlogZ(x) \begin{aligned} r(x, y) &= \beta \left( \log \pi_r (y|x) - \log \pi_{\phi_\text{ref}} (y|x) + \log Z(x) \right) \\ &= \beta \log \frac{\pi_r(y|x)}{\pi_{\phi_{\text{ref}}}(y|x)} + \beta \log Z(x) \end{aligned}

eq(12.15)の導出

こちらは、eq(12.1)にeq(12.14)を代入すると、Z(x)Z(x)の項がキャンセルされるだけなので省略します。

損失関数を理解する

さて、今までは無味乾燥な式変形をひたすらしてきました。
それもこれも、下記のDPO損失関数を得るためでした。

LDPO(ϕ)=E(x,y+,y)Dp[logσ(βlogπϕ(y+x)πϕref(y+x)βlogπϕ(yx)πϕref(yx))] \mathcal{L}_{\text{DPO}}(\phi) = -\mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}_p} \left[ \log \sigma \left( \beta \log \frac{\pi_\phi(y^+|x)}{\pi_{\phi_{\text{ref}}}(y^+|x)} - \beta \log \frac{\pi_\phi(y^-|x)}{\pi_{\phi_{\text{ref}}}(y^-|x)} \right) \right]

この式はどんな意味があるのでしょうか。

そもそも、損失関数の大元はこちらになっています。

LDPO(ϕ)=E(x,y+,y)Dp[logp(y1>y2x)] \mathcal{L}_{\text{DPO}}(\phi) = -\mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}_p} \left[ \log p^* (y^1 > y^2 | x) \right]

これは、eq(12.1)とeq(12.2)からもわかると思います。
意味としては、プロンプトxxが与えられたとき、応答y1y^1が応答y2y^2よりも好まれる確率p(y1>y2x)p^* (y^1 > y^2 | x)を最大にしたい(LDPO(ϕ)\mathcal{L}_{\text{DPO}}(\phi)はマイナスをかけてるので最小化したい)というものでした。
なので、大元の式にeq(12.15)を代入すればDPO損失関数が得られるというわけです。
つまるところ、本質はeq(12.2)もeq(12.15)も同じになります。

ただし、eq(12.2)

L(θ)=E(x,y+,y)Dp[log(σ(rθ(x,y+)rθ(x,y)))] \mathcal{L}(\theta) = -\mathbb{E}_{(x,y^+,y^-) \sim D_{\text{p}}} \left[ \log \left( \sigma \left( r_\theta(x, y^+) - r_\theta(x, y^-) \right) \right) \right]

では、損失関数に報酬rθ(x,y)r_\theta(x, y^*)が含まれてますね。これでは、報酬rθ(x,y)r_\theta(x, y^*)を出力できるモデル(報酬モデル)を通さないと方策モデル学習できないことになりますね。
だから強化学習を使っていたわけです。

それが、DPO損失関数だと、損失関数に報酬がなくなり、代わりに求めたい方策が直接書かれてますね。
学習対象のLLM(方策モデル)の出力が含まれているため、この損失関数を用いて直接最適化できることを意味します。

そのことからこの本ではDPOは、RLHFと同様の訓練を勾配法を用いて直接行えるようにした手法と説明されていますね。
今ならこの意味が少しは理解できると思います。

DPOは過学習しやすい?

損失関数LDPO\mathcal{L}_{\text{DPO}}は、どんなときに小さくなるのでしょうか。
それは、 logπϕ(y+x)πϕref(y+x)\log \frac{\pi_\phi(y^+|x)}{\pi_{\phi_{\text{ref}}}(y^+|x)}が正の大きい値になり、logπϕ(yx)πϕref(yx)\log \frac{\pi_\phi(y^-|x)}{\pi_{\phi_{\text{ref}}}(y^-|x)}が負の(絶対値が)大きい値になるときです。
※ なぜなら、シグモイドの中身が大きくなるとσ()\sigma(\cdot)は1に近づき、そのlog\logはゼロに近い値になるため

logπϕ(y+x)πϕref(y+x)\log \frac{\pi_\phi(y^+|x)}{\pi_{\phi_{\text{ref}}}(y^+|x)}が大きい値を取るには、πϕ(y+x)πϕref(y+x)\frac{\pi_\phi(y^+|x)}{\pi_{\phi_{\text{ref}}}(y^+|x)}が1より大きくなる、つまりπϕ(y+x)>πϕref(y+x)\pi_\phi(y^+|x)>\pi_{\phi_{\text{ref}}}(y^+|x)となり、両辺の差が大きければ大きいほどよい、ということになります。
これは、参照モデルπϕref\pi_{\phi_{\text{ref}}}が出力しにくい「好まれる出力y+y^+」でもなんでもお構いなしに、学習対象モデルπϕ\pi_\phiy+y^+を出力する確率を上げていこう、という趣旨となるので、πϕ\pi_\phiに無理やりy+y^+を出力させるよう学習を進めるためですね。

もっというとπϕref\pi_{\phi_{\text{ref}}}が出力しにくいy+y^+のほうが嬉しい、という、参照モデルからかけ離れた方向に学習させようとしているので、RLHFのときにおこなってた正則が壊れていますね。

次に第2項のlogπϕ(yx)πϕref(yx)\log \frac{\pi_\phi(y^-|x)}{\pi_{\phi_{\text{ref}}}(y^-|x)}を見てみましょう。
先ほどと逆なのですが、log\logが負の大きい値となるには、中身πϕ(yx)πϕref(yx)\frac{\pi_\phi(y^-|x)}{\pi_{\phi_{\text{ref}}}(y^-|x)}が1より小さい値にならなければなりません。そのためにはπϕ(yx)<πϕref(yx)\pi_\phi(y^-|x)<\pi_{\phi_{\text{ref}}}(y^-|x)となり、この差が大きければ大きいほどよいということになります。
これは、参照モデルπϕref\pi_{\phi_{\text{ref}}}が出力しやすかったデータyy^-のときに、思いっきりπϕ(yx)\pi_\phi(y^-|x)の確率を下げちゃおう、というものになります。
人間が好まないデータは絶対に出力しないぞ、という意思を感じますね。このとき、yy^-を出力しないように過学習が起こるのです。

DPOに関する説明は以上です。

Discussion

ログインするとコメントできます