Zenn
🤖

【選考チューニング】あるプロンプトをLLMに与えたときに、ある出力が得られる確率を求める

2025/03/22に公開

選考チューニングのDPOを勉強していて、理論は理解できたのですが、どうしてもπϕ(y1x)\pi_\phi(y^1|x)をどうやってLLMから算出するかがわからないままでした。
そんなこと知らなくても、ライブラリを使えば簡単に学習できるわけですが...ライブラリがあるから理論を知らなくていいということは決してないのでちゃんと調べてみました。

調べたことをメモがてら書いておきます(間違ってる可能性もありますが)。

定義

LLM: πϕ\pi_\phi
入力プロンプト: xx
任意の出力: y=w1,w2,,wNy = w_1, w_2, \dots, w_N
狙った出力: y1=w11,w21,,wN1y^1 = w_1^1, w_2^1, \dots, w_N^1

語彙: VV
語彙数: V|V|
LLMモデルの出力: logitsと呼ぶ(入力のトークン数×V\times|V| のテンソル)
logitをsoftmaxしたもの: Probabilitiesと呼ぶ(入力のトークン数×V\times|V| のテンソル)

本題

プロンプトxxをLLMに与えたときに出力yyを得る確率

パラメータϕ\phiを持つLLMπϕ\pi_\phiにプロンプトxxを渡したときに、出力y=w1,w2,,wNy = w_1, w_2, \dots, w_Nを得られる確率は下記の式で表されます。

P(yx,ϕ)=πϕ(yx)=i=1Nπϕ(wix,w<i) P(y|x, \phi) = \pi_\phi(y|x) = \prod_{i=1}^N \pi_\phi (w_i | x, w_{<i})

これを求めるだけなら、行けそうですね。

まずはπϕ(w1x)\pi_\phi (w_1 | x)の確率を求める。
これは、LLMに入力xxを渡したときの出力(logit)にsoftmax変換をさせたProbabilitiesの最後の要素から確率を求められますね。Python風に書くとProbabilities[-1][出力トークンID]でしょうか。

πϕ(w2x,wq)\pi_\phi (w_2 | x, w_q)の確率を求める。
これは、LLMに入力x,w1x, w_1を結合して渡したときの、Probabilities[-1][出力トークンID]です。

と言うように、順にトークンwiw_iを出力させて、その確率を使えばP(yx,ϕ)P(y|x, \phi)を計算することができますね。

プロンプトxxをLLMに与えたときに、狙った出力y1y^1を得る確率

上記のように任意の出力yyを得る確率P(yx,ϕ)P(y|x, \phi)は簡単に求められることがわかりました。
これはLLMを自然に実行したときに得られるProbabilitiesを使えば簡単に求められました。
では、狙った出力y1y^1を出力する確率P(y1x,ϕ)P(y^1|x, \phi)を得るにはどうしたらよいでしょう。

結論、このように算出しているみたいです。

まず、LLMにx,yx, yを結合させたものを入力し、Probabilitiesを得ます。
言葉でうまく説明できないので、ところどころPythonコード風になりますが、
Probabilities[len(x)-1][y^1の1トークン目のトークンID]
Probabilities[len(x)][y^1の2トークン目のトークンID]
Probabilities[len(x)+1][y^1の3トークン目のトークンID]
...
Probabilities[len(x)+len(y)-2][y^1のlen(y)トークン目のトークンID]
がそれぞれ、πϕ(w11x),πϕ(w21x,w11),πϕ(w31x,w11,w21),,πϕ(wNx,w<N)\pi_\phi (w_1^1 | x), \pi_\phi (w_2^1 | x, w_1^1), \pi_\phi (w_3^1 | x, w_1^1, w_2^1), \dots, \pi_\phi (w_N | x, w_{<N})に対応します。

さて、なぜx,y1x, y^1を結合させたものを入力とした出力が、P(y1x,ϕ)P(y^1|x, \phi)となるのでしょう?
未来の情報が入ってしまって大丈夫でしょうか?

じつは、大丈夫なのです。
LLMにはCausal Attentionという、未来の情報をマスクするAttentionが使われてます。このことから、上記のようにx,y1x, y^1を入力として確率を計算しても問題のです。

調査終わり。

Discussion

ログインするとコメントできます