Zenn
🍰

出力品質が下がらないLLM推論高速化手法「投機的デコーディング」

2024/12/23に公開

https://qiita.com/advent-calendar/2024/ca-26th

はじめに

u-hyszkと申します!

本記事では大規模言語モデル(LLM)の推論高速化手法である投機的デコーディング(Speculative Decoding)をご紹介します。
また、投機的デコーディングの出力品質が低下しないというメリットを、アルゴリズムや証明に触れながら説明します。

投機的デコーディングとは?

投機的デコーディングとは、小規模なモデルで「先読み」を行うことで処理を高速化する手法です。

通常、言語モデルでは1トークンずつ推論を行います。しかし、1回の推論ごとに1つのトークンしか生成できないため、テキストの生成速度が遅くなるという課題があります。

一方、投機的デコーディングでは、「ドラフトモデル」と呼ばれる小規模なモデルを使用して高速に複数のトークンを「先読み」します。その後、「ターゲットモデル」と呼ばれる大規模なモデルで並列に出力を検証する(投機的サンプリング)ことで、ターゲットモデル1回の推論につき複数のトークンを同時に生成することが可能となり、高速化を実現します。

投機的デコーディング
参考文献[1]より引用

投機的デコーディングの最大のメリットは、理論的に「出力品質が低下しないこと」が保証されている点です。この特性により、リアルタイム性と出力品質が同時に求められる音声対話などの用途で活用しやすくなります。参考文献[2][3][4]では、出力品質を下げずに以下のようなスピードアップが得られることが報告されています。

ターゲットモデル ドラフトモデル 言語 タスク スピードアップ
T5-XXL (11B params) T5-small (77M params) 英語・ドイツ語 英独翻訳 (WMT EnDe) 3.4倍
T5-XXL (11B params) T5-small (77M params) 英語 要約 (CCN/DM) 3.1倍
Chinchilla 70B Chinchilla 4B 英語 要約 (XSum) 2.01倍
Chinchilla 70B Chinchilla 4B - コード生成 (HumanEval) 2.46倍
日本語言語モデル (409M) 日本語言語モデル (47M) 日本語 要約 (XLSum) 2.02倍

一方で、デメリットとしては、ドラフトモデルを追加で運用する必要がある点が挙げられます。また、使用するドメインに応じて適切なドラフトモデルを選定するコストがかかるという欠点もあります。しかし、ドラフトモデルで推論を行うための計算環境が整っている場合には、投機的デコーディングは非常に強力な手段となり得ます。

投機的デコーディングの使用例
Google検索での使用例; 参考文献[1]より引用

アルゴリズム

以下に、投機的デコーディングの大まかなアルゴリズムをフローチャートで示します。

アルゴリズム

1. ドラフトモデルによる候補トークン生成

まず、ドラフトモデルを使用して高速にγ\gamma回の推論を行い、候補となるトークンを生成します。

2. ターゲットモデルでの並列検証

次に、生成された候補トークンを含め、ターゲットモデルで並列に推論を実行し、γ+1\gamma+1トークン分の出力分布を取得します。

3. 候補トークンの受理

各候補トークンについて、以下の確率でその候補トークンを受理します。

min(1,q(x)p(x))(A) \min(1, \frac{q(x)}{p(x)}) \qquad (A)

ただし、ドラフトモデルがトークンxxを出力する確率をp(x)p(x)、ターゲットモデルがトークンxxを出力する確率をq(x)q(x)とします。

4. 棄却された場合の処理

候補トークンの一部が棄却された場合、以下の式に基づいて次のトークンを決定します。それ以降の候補トークンは全て棄却され、再度1からやり直します。

xmax(0,q(x)p(x))xmax(0,q(x)p(x))(B) x \sim \frac{\max{(0, q(x) - p(x))}}{\sum_{x'} \max{(0, q(x') - p(x'))}} \qquad (B)

ただし、\simはサンプリング(その確率分布に従ってランダムにトークンを選ぶこと)を表します。

このプロセスを繰り返し、全ての候補トークンが受理された場合には、ターゲットモデルの最後の出力分布を使用してトークンを生成します。

上記のアルゴリズムで使った数式(A)と(B)が後述する証明の中で重要な役割を果たします

「出力品質が下がらない」の証明

参考文献[3]の証明の大まかな流れを図に示します。

証明の全体としては、投機的デコーディングによって得られるトークンの確率分布P(X=x)\mathbb{P}(X=x)とターゲットモデルが通常のデコーディングで生成するトークンの確率分布q(x)q(x)が同じになることを示します。

証明の流れ1

この証明では、候補トークンx~\tilde{x}が受理された場合と棄却された場合で場合分けを行います。

証明の流れ2

「候補トークンx~\tilde{x}が受理された」という事象は、より正確には「ドラフトモデルにより候補トークンx~\tilde{x}が生成されて,その条件下でx~\tilde{x}がターゲットモデルに受理される」という事象です。

証明の流れ3

同様に,「候補トークンx~\tilde{x}が棄却された」という事象は,「ある候補トークンx~\tilde{x}が棄却され、その条件下で棄却されたトークンの代わりにトークンxxを生成する」という事象です。

証明の流れ4

このうち候補トークンがターゲットモデルに受理される基準(数式(A))と,棄却されたトークンの代わりにトークンxxを生成する基準(数式(B))をアルゴリズムの中で設定していました。

証明の流れ5

これらの基準により、他の全ての事象の確率が一意に決定されます。そして、最終的に得られる確率分布P(X=x)\mathbb{P}(X=x)が、ターゲットモデルの確率分布q(x)q(x)と一致することが証明されます。

このように、投機的デコーディングで得られる確率分布と通常のデコーディングで得られる確率分布が一致するよう設計されているため、「出力品質が下がらない」ことが保証されます。

より詳細な証明

場合分け

投機的デコーディングにより、あるトークンxxが生成されたとき、すなわちX=xX=xのとき、

  • (1)ドラフトモデルによりx~=x\tilde{x}=xが生成され、かつx~\tilde{x}がターゲットモデルに受容される
  • (2)ドラフトモデルによりxx以外のトークンx~=x\tilde{x}=x'が生成され、ターゲットモデルがx~\tilde{x}を拒否して代わりにxxを生成する

の2つの事象に場合分けできる。

したがって、投機的デコーディングによって得られる離散分布P(X=x)\mathbb{P}(X=x)は以下のように表現できる。

P(X=x)=P(x~=xx~が受容される)+P(x~が拒否されるX=x)=P(x~=x)P(x~が受容されるx~=x)+P(x~が拒否される)P(X=xx~が拒否される) \mathbb{P}(X = x) \\ = \mathbb{P}(\tilde{x}=x \cap \tilde{x}が受容される) + \mathbb{P}(\tilde{x}が拒否される \cap X=x) \\ = \mathbb{P}(\tilde{x}=x)\mathbb{P}(\tilde{x}が受容される | \tilde{x}=x) + \mathbb{P}(\tilde{x}が拒否される)\mathbb{P}(X=x|\tilde{x}が拒否される)

以下、左項と右項に分けて導出する。

左項

(1)の事象について、ドラフトモデルによりx~=x\tilde{x}=xが生成され、かつx~\tilde{x}がターゲットモデルに受容される確率を以下のように定義する。

P(x~=xx~が受容される)=P(x~=x)P(x~が受容されるx~=x)=p(x)min(1,q(x)p(x))=min(p(x),q(x))(A) \mathbb{P}(\tilde{x}=x \cap \tilde{x}が受容される) \\ = \mathbb{P}(\tilde{x}=x)\mathbb{P}(\tilde{x}が受容される | \tilde{x}=x) \\ = p(x)\min(1, \frac{q(x)}{p(x)}) \\ = \min(p(x), q(x)) \quad \cdots(A)

右項

(2)の事象について、x~\tilde{x}が拒否された上でX=xX=xとなる条件付き確率を以下のように定義する。

P(X=xx~が拒否される)=max(0,q(x)p(x))xmax(0,q(x)p(x))(B) \mathbb{P}(X=x|\tilde{x}が拒否される) = \frac{\max(0, q(x)-p(x))}{\sum_{x'}\max(0, q(x')-p(x'))} \quad \dots(B)

この定義を用いて、x~\tilde{x}が拒否される確率を導出する。まず、排反事象である「x~\tilde{x}が受容される」を用いて、以下のように変形する。

P(x~が拒否される)=1P(x~が受容される) \mathbb{P}(\tilde{x}が拒否される) = 1 - \mathbb{P}(\tilde{x}が受容される)

ここで、「x~\tilde{x}が受容される確率」は、「任意のトークンxx'がドラフトモデルにより生成され、かつx~\tilde{x}が受容される確率の総和」と解釈できるので(条件付き確率の「条件」にあたる部分の全事象の確率を全て加算する)、

P(x~が受容される)=xP(x=xx~が受容される) \mathbb{P}(\tilde{x}が受容される) = \sum_{x'}\mathbb{P}(x=x'\cap\tilde{x}が受容される)

また、左項で導出した式(A)を用いると、

P(x~が受容される)=xP(x=xx~が受容される)=xmin(p(x),q(x))(C)(A)より \mathbb{P}(\tilde{x}が受容される) \\ = \sum_{x'}\mathbb{P}(x=x'\cap\tilde{x}が受容される) \\ = \sum_{x'}\min(p(x'), q(x')) \quad \cdots (C) \because 式(A)より

となるので、x~\tilde{x}が拒否される確率は

P(x~が拒否される)=1xmin(p(x),q(x))=xq(x)xmin(p(x),q(x))確率の定義xq(x)=1=x{q(x)min(p(x),q(x))} \mathbb{P}(\tilde{x}が拒否される) \\ = 1 - \sum_{x'}\min(p(x'), q(x')) \\ = \sum_{x'}q(x') - \sum_{x'}\min(p(x'), q(x')) \quad \because 確率の定義\sum_{x'}q(x')=1 \\ = \sum_{x'} \{ q(x') - \min(p(x'), q(x')) \}

ここで前式の総和記号の中身について、

q(x)min(p(x),q(x))={q(x)p(x)p(x)<q(x)の場合0p(x)>q(x)の場合=max(0,q(x)p(x)) q(x') - \min(p(x'), q(x')) \\ = \begin{cases} q(x') - p(x') & p(x') < q(x')の場合 \\ 0 & p(x') > q(x')の場合 \end{cases} \\ = max(0, q(x')-p(x'))

なので、最終的に得られるx~\tilde{x}が拒否される確率は

P(x~が拒否される)=xmax(0,q(x)p(x)) \mathbb{P}(\tilde{x}が拒否される) = \sum_{x'}\max(0, q(x')-p(x'))

である。従って、右項は式(B)を使用することで、

P(X=xx~が拒否される)P(x~が拒否される)=max(0,q(x)p(x))xmax(0,q(x)p(x))xmax(0,q(x)p(x))=max(0,q(x)p(x))(D) \mathbb{P}(X=x|\tilde{x}が拒否される)\mathbb{P}(\tilde{x}が拒否される) \\ = \frac{\max(0, q(x)-p(x))}{\sum_{x'}\max(0, q(x')-p(x'))}\sum_{x'}\max(0, q(x')-p(x')) \\ = \max(0, q(x) - p(x)) \quad \cdots (D)

と変形できる。

2つの項の足し算

式(A)、(D)より

P(X=x)=P(x~=x)P(x~が受容されるx~=x)+P(x~が拒否される)P(X=xx~が拒否される)=min(p(x),q(x))+max(0,q(x)p(x))={q(x)p(x)<q(x)の場合q(x)p(x)>q(x)の場合=q(x) \mathbb{P}(X = x) \\ = \mathbb{P}(\tilde{x}=x)\mathbb{P}(\tilde{x}が受容される | \tilde{x}=x) + \mathbb{P}(\tilde{x}が拒否される)\mathbb{P}(X=x|\tilde{x}が拒否される) \\ = \min(p(x), q(x)) + \max(0, q(x) - p(x)) \\ = \begin{cases} q(x) & p(x) < q(x)の場合 \\ q(x) & p(x) > q(x)の場合 \end{cases} \\ = q(x)

すなわち、投機的デコーディングによって得られる離散分布P(X=x)\mathbb{P}(X=x)は、ターゲットモデルの離散分布q(x)q(x)と等価である。\square

気になった方へ

Huggingfaceではassistant_modelというオプションで投機的デコーディングをサポートしています。
また、その発展系であるSelf-Speculative Decodingなども実装されているので、気になった方は確認してみてください。

from transformers import AutoModelForCausalLM, AutoTokenizer

prompt = "Alice and Bob"
checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained(checkpoint)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
outputs = model.generate(**inputs, assistant_model=assistant_model)
tokenizer.batch_decode(outputs, skip_special_tokens=True)

https://huggingface.co/docs/transformers/main/generation_strategies#speculative-decoding

投機的デコーディングに関する最新の研究をまとめたAwesomeレポジトリがあるので、こちらも確認してみてください。
https://github.com/hemingkx/SpeculativeDecodingPapers

終わりに

本記事では大規模言語モデル(LLM)の推論高速化手法である投機的デコーディングと、そのアルゴリズム・証明をご紹介しました。

個人的には投機的デコーディングは理解するのが難しいと感じた一方で、日本語での記事はまだ限られていると感じました。
本記事の内容が投機的デコーディングへの納得感を高めることに繋がっていれば幸いです!

ご覧いただきありがとうございました🍰

参考文献

[1]Looking back at speculative decoding
https://research.google/blog/looking-back-at-speculative-decoding/
[2]Fast Inference from Transformers via Speculative Decoding
https://openreview.net/pdf?id=C9NEblP8vS
[3]Accelerating Large Language Model Decoding with Speculative Sampling
https://arxiv.org/pdf/2302.01318
[4]日本語投機的デコーディングの検討
https://www.anlp.jp/proceedings/annual_meeting/2024/pdf_dir/P7-19.pdf

Discussion

ログインするとコメントできます