😊

最適制御とLLMのPrompt Tuning、更にその先へ

に公開

はじめに

前回はDSPyのことはじめとして下記の記事を書いた。
本当にさわりだけを紹介し、普通にLLMから応答を得るための書き方がどのようになるかをコードで示した。この下記記事ではおそらく下記のことを体験していただけたと思う。

  • LLMの選択、LLMの設定、応答の呼び出しが分離されており、コードとして見やすい
  • プロンプトを作って直接応答を得ることもできるし、シグネチャだけから応答を得ることもできる
  • シグネチャや処理をクラスとして設計することができる
  • CoTのような典型的な工夫はクラスとして準備されている

https://zenn.dev/cybernetics/articles/f879e10b53c2db

上記までの事実を把握しただけでは、なんだかコードを書くのがちょっと便利になっただけに感じたり、プロンプトが隠蔽されて振る舞いがよくわかんなくなってしまっていたり、そういう印象を持つと思う。

実査にはコードを書くのが便利になっているのも、プロンプトが隠蔽されているのも、Prompt tuningを便利に取り扱うためとひとまず考えてよい。

この記事では、Prompt tuningが何者であるか?を理解してもらったうえで、Prompt tuningの簡単な例を示す。

教師あり学習

話が逸れるが教師あり学習について一旦まとめる。これは機械学習の最も基本的なタスク設定の一つである。入力と出力のペアが与えられ、モデルにペアを大量に与えることでいい感じに関係性を学習してもらう。数式であらわすと下記のような形式となる。

データ集合 \mathcal{D} =\{(x_i,y_i)\}_{i=1}^N に対し、パラメータ \theta を持つモデル f_\theta を学習する典型形は下記のような最適化問題の形式を持つ。

\hat{\theta} = \argmin_{\theta}\; \frac{1}{N}\sum_{i=1}^N \mathcal{L}\!\big(f_\theta(x_i),\, y_i\big)

ここで \mathcal{L} は損失関数(分類なら交差エントロピー、回帰ならに二乗誤差など)となる。重要なのはモデル f_\theta に対して、データ集合 D を使って \theta をチューニングしているということである。

システム同定

更に話が逸れるがシステム同定という問題がある。モデルは現在の状態と外部からの入力を受け取り、次の状態が決まるような時系列性を有する。数式だと x_{t+1} = f_\theta(x_t,u_t) という様相だ。システム同定とは制御の分野で使われる言葉であるが、要するに時系列モデルの教師あり学習だと思ってよい。

データ集合 \mathcal{D} =\{x_t,u_t,x_{t+1})\}_{t=1}^T を準備すれば形式的には下記のような最適化問題を書いてみることができる。

\hat{\theta} = \argmin_{\theta}\; \frac{1}{T}\sum_{i=1}^{T} \mathcal{L}\!\big(f_\theta(x_t,u_t),\, x_{t+1}\big)

上記の問題を解いて \theta を固定した暁には、モデルに外部から u_t を与えたときに x_t から x_{t+1} に変化するであろうということが予測可能になる。実際の制御の分野では上記の問題を色々工夫して解くため、これを勾配法で殴るという単純な形式では教科書には載っていなかったりするが、概要としては十分だ。また、時系列のデータ集合を何個も何個もたくさん準備していろんな遷移パターンを準備しておくのも重要なことだ。この話をしたのは、言語モデルの事前学習と大いに関連しているためである。

最適制御

次に最適制御という問題を考える。これはドメインナレッジかシステム同定によって得られて時系列モデル x_{t+1} = f_\theta(x_t,u_t) を所与として扱う。

最適制御問題では、理想的なモデルの状態列 \mathbf = x_{1:T+1} とモデル f_\theta から、モデルの状態を理想通りに遷移させるための入力列 \mathbf = u_{1:T} を求めるという試みをする。

\hat{\mathbf u} = \argmin_{\mathbf u}\; \frac{1}{T}\sum_{i=1}^{T} \mathcal{L}\!\big(f_\theta(x_t,u_t),\, x_{t+1}\big)

システム同定の式と変わったとは、モデル f_{theta} が所与となった代わりに \mathbf u が最適かしたいパラメータになったという点である。また、実際には状態の方も最適化パラメータとして動かす定式化もできる。

(\hat{\mathbf u}, \hat{\mathbf x}) = \argmin_{\mathbf u \mathbf x}\; \frac{1}{T}\sum_{i=1}^{T} \mathcal{L}\!\big(f_\theta(x_t,u_t),\, x_{t+1}\big)

この場合にはモデル f が等式制約としてふるまう。状態と入力を好き勝手動かしながら最適化するけど x_{t+1} = f_\theta (x_t, u_t) の関係性は守ってねということである。数学的には解は同じだろうが、数値的に解くときにはいろいろな工夫の違いが出たりする。

ここで、時系列が非常に長い場合は最適化したいパラメータが \mathbf u = (u_1, u_2, ..., u_{100000000}, ...) などと増えていってしまうかもしれない。そうなった場合には、例えば後半の方は適当な一定値を与えてモデルが適当に時間発展するかもしれないことを許容してパラメータを減らすことができる。極端な話、u_1,u_2 のみをパラメータとして動かし、それ以降の時刻を u_2 で固定してしまうのだ。

(\hat u_1, \hat u_2, \mathbf x) = \argmin_{u_1, u_2, \mathbf x}\; \frac{1}{T}\sum_{i=1}^{T} \mathcal{L}\!\big(f_\theta(x_t,u_t),\, x_{t+1}\big)

最適化問題は非常に簡便になる一方で、u_3 以降の振る舞いは狙い通りにならないかもしれない。無論、固定された定数列 u_3, u_4, ... で可能な限り \mathbf x をなぞるように最大限の努力はする。時刻 t = 3 になったら、再度同じ問題を解きなおすことにして入力を考え直せば、まあさほど酷いことにはならないだろうというやり方だ。

要するに \mathbf u, \mathbf x は制御の言葉では制御入力と状態であるが、最適化から見たらそれらが何者であるかは興味が無いので、良しなに求めたいものを最適化パラメータに設定してあげればよい。

Prompt Tuning(プロンプトチューニング)

ここからは、教師あり学習・システム同定・最適制御の流れを踏まえ、Prompt TuningがLLMにおいて何を最適化するのかを明確にする。結論は単純で、LLM本体のパラメータ \theta は凍結し、入力の前置きのみを最適化で調整する。最適制御的な書き方をしてみよう。LLMは自己回帰的な生成器であるから、最も単純化した表現は

x_{t+1} = f_\theta(x_t)

である。初期に与えたプロンプト列 \mathbf{x}=(x_1,\dots,x_K) を条件として x_{K+1} を生成し、以後は自己回帰で続く。

トークン空間と埋め込み空間の区別

記号を整理する。自然言語のトークン空間を X、埋め込み空間(次元 d)を Z とする。トークン列は埋め込み写像(位置埋め込み等を含む)により

\mathrm{emb}:\; X \to Z,\qquad \mathrm{emb}(\mathbf{x})=\mathbf{z}\in Z^{L\times d}

へ写される。自然言語自体を最適制御的に最適化するのはえげつない組み合わせ最適化になってしまいそうなので、連続空間で埋め込んでおく。埋め込んだ先で連続最適化として取り扱えば最適化には勾配法が使えることになるだろう。以降は自然言語のトークン列を連続空間に埋め込んだ前提で議論を進める。

プロンプトの分解と最適化変数

さて、LLMに投げるプロンプトはたいていの場合、共通で利用可能な前置き部分と、個別の問いの部分に分けることができる。

\mathbf{x}=(\mathbf{x}_{sys},\,\mathbf{x}_{usr}),\qquad \mathbf{x}_{sys}=(x_1,\dots,x_k),\quad \mathbf{x}_{usr}=(x_{k+1},\dots,x_K)

これを埋め込みに写すと

\mathbf{z}_{sys}=\mathrm{emb}(\mathbf{x}_{sys})\in Z^{k\times d},\qquad \mathbf{z}_{usr}=\mathrm{emb}(\mathbf{x}_{usr})\in Z^{(K-k)\times d}

となる。そういう空間でデータを扱っているということだけが分かればよい。大抵はLLMの事前学習の段階で埋め込む方法自体も作り終えてある。重要なのは、自然言語で \mathbf{x}_{sys} という共通で利用可能な前置き部分を作るというのをやめてしまう発想だ。これをわざわざ人間が自然言語で関gは得るのは面倒である。この役割を最初から連続ベクトル列 p に置換する。すなわち

p\in Z^{m\times d}

を導入し、モデルへの入力は個別の問いの部分 \mathbf x_{usr} を埋め込んだ \mathbf {z}_{usr}に対して、その前置きとして p を連結し

[p;\,\mathbf{z}_{usr}] \in Z^{(m+K-k)\times d}

として与える。ここで、p のみを、以降最適化変数として扱ってしまおうということである。最適化問題は下記のような形式で書かれると思ってよい。

\hat p = \argmin_{p}\; \frac{1}{N}\sum_{i=1}^{N} \mathcal{L}\!\big(f_\theta([p;\,\mathbf{z}_{usr}^{(i)}]),\,x_i\big)

である。要するに埋め込まれた前置きのプロンプト p を上手に決めてあげることで、ユーザーの様々な問いかけ \mathbf z_{usr} = \mathrm{emb}(\mathbf x_{usr}) に対応できるようになることを期待するのだ。

我々がやりたかったのは問いかけに対して正しい自己回帰による推論を繰り返し、正しい回答をすることである。これに対して、モデル自体をチューニング \theta の最適化はコストが重すぎる。モデル自体はある程度妥当なダイナミクスを内部に有していると信ずるならば、最適制御的に外側で上手く帳尻を合わしたいのである。Prompt tuningの場合は初期状態に相当する何かを調整してあげることでうまくいかないか?ということである。

結局、最適制御という言葉を使ったのは実は冗長だったらしい。だが、最適制御問題の一般系から見たとき、そのサブセットとしてプロンプトチューニング問題をとらえることはできそうだ。しらんけど。

Prompt Tuning の外側

ここまで来たら妄言である。制御入力としてのLLMから見た外生情報をどう定式化するかを考えてみたい。

以上までで、前置き p\in Z^{m\times d} を学習する Prompt Tuning を定式化した。ここからは、最適制御のアナロジーをさらに押し広げたい。最適制御自体には各時刻に制御入力を入れ込む(外生的な情報を付与する)ことができる。というか普通は、その外生的な入力を好き勝手に選べるとしたら、どのような系列にすることでダイナミクスを制御できるかを考えるのが最適制御である(実際は不等式制約などがあり、本当に好き勝手そうさできるわけではないが)。

同じようにLLMというダイナミクスを何らかの方法で制御したいというのは当然である。例えば、**温度パラメータや top-p といったデコーディング方策、RAG による外部コンテキスト、Tool Use(関数呼び出し)**などを「制御入力」に割り当てることができないかを考えてみたいのである。

拡張状態と制御入力

自己回帰生成の内部状態(隠れ状態)を h_t、これまでの生成列を w_{1:t} とし、前置きは学習済みの p とする。外生情報を次のように束ねる。

  • デコーディング方策\phi_t \in \Phi(温度、top-p、長さペナルティ等)
  • 外部コンテキストc_t \in Z^{\ell_t\times d}(RAG の取得結果等の埋め込み列)
  • ツール出力a_t(関数呼び出しの返り値、計算結果、API 応答)

これらをまとめて 制御入力 u_t := (\phi_t, c_t, a_t) とおく。生成は

w_{t+1}\sim P_\theta\!\big(\cdot \,\big|\, w_{1:t},\,[p;\,\mathbf{z}_{usr};\,c_{1:t}],\,\phi_t,\,a_{1:t}\big)

のように書ける。ここで [\,\cdot\,;\,\cdot\,] は埋め込み空間 Z における連結である(前節と同一の前提)。こうすれば、解けるかは分からないけど、デコーデイング方法やRAG情報の入れ方、ツールの出力の使い方などを最適化パラメータとして扱える。もしかしたらツールを使うか否かみたいな離散変数にしたり、RAGの検索範囲を指定したりするかもしれない。

最適制御でも時間を掛けて良い問題は、混合整数計画問題の様相を呈することはある。ただ、これらを完全に自由に動ける自由なパラメータにするとどう考えても現実的ではない気がしないでもない。

制御ポリシーの導入

u_t をその場しのぎで決めるのではなく、履歴に依存するポリシーで決めるのはどうだろうか。実はむしろ制御工学ではこちらが通常の思考である。最適制御はかなり力技で制御方法を決めているが、普通の制御工学では、都度現在の状態(ここで言うとLLMの出力)を監視しながら、それに応じて次に入れる制御入力(例えばツールを使うかや温度を調整するなど)を決定する関数を構える。その関数を通常はポリシーと呼ぶ。

u_t = \kappa_\psi\!\big(w_{1:t},\,p,\,\text{retrieval index},\,\text{tool states}\big)

ここで \psi は制御側のパラメータである。たとえば

  • 取得方策(RAG):c_t = r_\psi(w_{1:t})
  • ツール呼び出し方策:a_t = g_\gamma(w_{1:t})
  • デコーディング方策スケジューリング:\phi_t = s_\eta(t, w_{1:t})

のように分解可能である。これらは 開ループ(時間で決め打ち) でも 閉ループ(履歴で更新) でもよい。何はともわれ、勝手に自由に動けたパラメータLLMのプロンプトに応じて決定される形式となり、パラメータが \psi, \gamma, \eta に埋め込まれることとなった。

目的関数

データ集合 \mathcal{D}=\{(\mathbf{x}_{usr}^{(i)},\,y_i)\}_{i=1}^N に対し、展開過程全体の損失を評価し、前置き p と制御ポリシー \psi,\gamma,\eta を同時に(あるいは交互に)最適化する。そういう発想が力技ではあるが、おそらく最も原始的な考え方だと思う。

\min_{p,\psi,\gamma,\eta}\;\frac{1}{N}\sum_{i=1}^N \underbrace{\mathcal{L}\big(\operatorname{rollout}_\theta(\mathbf{z}_{usr}^{(i)},\,p;\,\kappa_{\psi,\gamma,\eta}),\,y_i\big)}_{\text{応答品質}} \;+\;\lambda_p\|p\|^2\;+\;\Omega(\psi,\gamma,\eta)

\Omega は取得コストやツール利用回数、待ち時間、トークン長などの実運用コストを罰則化する項である。最適制御の言葉で言えば、「性能」と「制御コスト」の合成目的である。こんなものを解くことは(ポリシーの大きさによるが)、ファインチューニングより辛いかもしれないし、問いを大量に準備しなければポリシーが汎化する気もしないので怪しいが、おおざっぱに考える分にはよさそうだ。

MPC 的運用(再ceding horizon)

長い推論を一気に最適化する代わりに、数トークン先だけ見通して更新する運用も有効である。すなわち、

  1. 現時点の履歴で u_t を決める(RAG / Tool / 方策)。
  2. 数ステップ生成して評価。
  3. 次の時点で再び 1 に戻る。

これは最適制御の MPC(Model Predictive Control) に対応し、RAG の逐次取得や CoT の分岐にも自然に適用できるかもしれない。

さいごに

自分の機械学習や制御の知見からLLMにアプローチしていくのは僕自身の理解を深める上でも、そしてモチベーションを保つためにも重要なことらしいことが分かった。ここまで読んでいる人がいたら相当物好きだけど、何か関連論文とか知っていたら是非とも教えてほしい。LLMの勉強は更に進めていきたい。

ツイッターやってるので連絡もらえると助かります。

https://x.com/ML_deep

Discussion