出力品質が下がらないLLM推論高速化手法「投機的デコーディング」
はじめに
u-hyszkと申します!
本記事では大規模言語モデル(LLM)の推論高速化手法である投機的デコーディング(Speculative Decoding)をご紹介します。
また、投機的デコーディングの出力品質が低下しないというメリットを、アルゴリズムや証明に触れながら説明します。
投機的デコーディングとは?
投機的デコーディングとは、小規模なモデルで「先読み」を行うことで処理を高速化する手法です。
通常、言語モデルでは1トークンずつ推論を行います。しかし、1回の推論ごとに1つのトークンしか生成できないため、テキストの生成速度が遅くなるという課題があります。
一方、投機的デコーディングでは、「ドラフトモデル」と呼ばれる小規模なモデルを使用して高速に複数のトークンを「先読み」します。その後、「ターゲットモデル」と呼ばれる大規模なモデルで並列に出力を検証する(投機的サンプリング)ことで、ターゲットモデル1回の推論につき複数のトークンを同時に生成することが可能となり、高速化を実現します。

参考文献[1]より引用
投機的デコーディングの最大のメリットは、理論的に「出力品質が低下しないこと」が保証されている点です。この特性により、リアルタイム性と出力品質が同時に求められる音声対話などの用途で活用しやすくなります。参考文献[2][3][4]では、出力品質を下げずに以下のようなスピードアップが得られることが報告されています。
| ターゲットモデル | ドラフトモデル | 言語 | タスク | スピードアップ |
|---|---|---|---|---|
| T5-XXL (11B params) | T5-small (77M params) | 英語・ドイツ語 | 英独翻訳 (WMT EnDe) | 3.4倍 |
| T5-XXL (11B params) | T5-small (77M params) | 英語 | 要約 (CCN/DM) | 3.1倍 |
| Chinchilla 70B | Chinchilla 4B | 英語 | 要約 (XSum) | 2.01倍 |
| Chinchilla 70B | Chinchilla 4B | - | コード生成 (HumanEval) | 2.46倍 |
| 日本語言語モデル (409M) | 日本語言語モデル (47M) | 日本語 | 要約 (XLSum) | 2.02倍 |
一方で、デメリットとしては、ドラフトモデルを追加で運用する必要がある点が挙げられます。また、使用するドメインに応じて適切なドラフトモデルを選定するコストがかかるという欠点もあります。しかし、ドラフトモデルで推論を行うための計算環境が整っている場合には、投機的デコーディングは非常に強力な手段となり得ます。

Google検索での使用例; 参考文献[1]より引用
アルゴリズム
以下に、投機的デコーディングの大まかなアルゴリズムをフローチャートで示します。

1. ドラフトモデルによる候補トークン生成
まず、ドラフトモデルを使用して高速に
2. ターゲットモデルでの並列検証
次に、生成された候補トークンを含め、ターゲットモデルで並列に推論を実行し、
3. 候補トークンの受理
各候補トークンについて、以下の確率でその候補トークンを受理します。
ただし、ドラフトモデルがトークン
4. 棄却された場合の処理
候補トークンの一部が棄却された場合、以下の式に基づいて次のトークンを決定します。それ以降の候補トークンは全て棄却され、再度1からやり直します。
ただし、
このプロセスを繰り返し、全ての候補トークンが受理された場合には、ターゲットモデルの最後の出力分布を使用してトークンを生成します。
上記のアルゴリズムで使った数式(A)と(B)が後述する証明の中で重要な役割を果たします。
「出力品質が下がらない」の証明
参考文献[3]の証明の大まかな流れを図に示します。
証明の全体としては、投機的デコーディングによって得られるトークンの確率分布

この証明では、候補トークン

「候補トークン

同様に,「候補トークン

このうち候補トークンがターゲットモデルに受理される基準(数式(A))と,棄却されたトークンの代わりにトークン

これらの基準により、他の全ての事象の確率が一意に決定されます。そして、最終的に得られる確率分布
このように、投機的デコーディングで得られる確率分布と通常のデコーディングで得られる確率分布が一致するよう設計されているため、「出力品質が下がらない」ことが保証されます。
より詳細な証明
場合分け
投機的デコーディングにより、あるトークン
- (1)ドラフトモデルにより
が生成され、かつ\tilde{x}=x がターゲットモデルに受容される\tilde{x} - (2)ドラフトモデルにより
以外のトークンx が生成され、ターゲットモデルが\tilde{x}=x' を拒否して代わりに\tilde{x} を生成するx
の2つの事象に場合分けできる。
したがって、投機的デコーディングによって得られる離散分布
以下、左項と右項に分けて導出する。
左項
(1)の事象について、ドラフトモデルにより
右項
(2)の事象について、
この定義を用いて、
ここで、「
また、左項で導出した式(A)を用いると、
となるので、
ここで前式の総和記号の中身について、
なので、最終的に得られる
である。従って、右項は式(B)を使用することで、
と変形できる。
2つの項の足し算
式(A)、(D)より
すなわち、投機的デコーディングによって得られる離散分布
気になった方へ
Huggingfaceではassistant_modelというオプションで投機的デコーディングをサポートしています。
また、その発展系であるSelf-Speculative Decodingなども実装されているので、気になった方は確認してみてください。
from transformers import AutoModelForCausalLM, AutoTokenizer
prompt = "Alice and Bob"
checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt")
model = AutoModelForCausalLM.from_pretrained(checkpoint)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
outputs = model.generate(**inputs, assistant_model=assistant_model)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
投機的デコーディングに関する最新の研究をまとめたAwesomeレポジトリがあるので、こちらも確認してみてください。
終わりに
本記事では大規模言語モデル(LLM)の推論高速化手法である投機的デコーディングと、そのアルゴリズム・証明をご紹介しました。
個人的には投機的デコーディングは理解するのが難しいと感じた一方で、日本語での記事はまだ限られていると感じました。
本記事の内容が投機的デコーディングへの納得感を高めることに繋がっていれば幸いです!
ご覧いただきありがとうございました🍰
参考文献
[1]Looking back at speculative decoding [2]Fast Inference from Transformers via Speculative Decoding [3]Accelerating Large Language Model Decoding with Speculative Sampling [4]日本語投機的デコーディングの検討
Discussion