テキスト生成APIサーバのスループットを高めるbatching algorithms
はじめに
テキスト生成モデルをAPIサーバでホストする需要が増えてきている昨今ですが1サーバでできるだけ多くのリクエストをさばくためにはどうすればよいでしょうか?もちろん高速なツールを使うことも重要ですが、それだけでは限界があります。前回の記事ではいくつかのツールを比較しましたが、どのツールでもバッチサイズを上げることで単位時間あたりの処理能力を高めることができるということがわかりました。つまりAPIサーバ側でバッチサイズを大きくする工夫をすることでより多くのリクエストをさばくことが可能になります。
今回の記事ではText Generation InferenceやvLLMなどが採用して注目を集めているContinuous batchingと呼ばれる手法について紹介します。
名称や仕組みなどについてはこれらの解説を参考にしています。
予備知識
Continuous batchingの説明に使われる用語の解説のために、多くのテキスト生成モデルで使われているcausal language modelingについて簡単に説明します。causal language modelingではトークン

iter1ではまず与えられたプロンプトのトークンKV cache と呼びます。iter2*でも同様に計算結果を保存して、iter3でも KV cache から計算できるようにする…というのを繰り返してテキスト生成が進みます。
ここからは最初のiter1を prefill フェーズ、iter2*, iter3を decode フェーズと呼びます。
batching algorithms
予備知識を踏まえて本題の複数のリクエストをバッチ化するアルゴリズムを紹介します。
Static batching (no scheduling)
サーバサイドではなにもしないという意味です。クライアントから送られてきたバッチサイズのままモデルに渡してテキストを生成します。

スループットはクライアントに完全に依存します。クライアントがコントロール下にあり、リクエストのバッチサイズやプロンプトの長さをコントロールできる場合は効率的です。
Static batchingは前回の記事で紹介したように多くのライブラリでサポートされています。
Dynamic batching (request-level scheduling)
クライアントからのリクエストをキューに溜めてサーバサイドでバッチにまとめます。まとめたバッチをモデルに渡してテキストを生成します。

クライアントが送るリクエストのバッチサイズによらずにバッチサイズを大きくできるためスループットの向上が期待できます。ただしバッチ内で生成されるテキストの長さが大きく異なる場合は、一番長い生成が完了するまで次のバッチの処理を始められないため、期待したほどスループットが上がらない場合もあります。
実装上はサーバサイドでリクエストをバッチにまとめたあとは、Static batchingと同様なので綺麗に分離できることがメリットです。
たとえばTriton Inference ServerはDynamic batchingをサポートしていて、バッチにまとめたものをバックエンドに渡すことができます。そのためバックエンドはStatic batchingのみをサポートしている任意のバックエンドと組み合わせて使うことができて便利です。
バックエンドとしてはNVIDIAのFasterTransformerと組み合わせたり、Python backend経由でtransformersやDeepSpeedと組み合わせたりできます。
Continuous batching (iteration-level scheduling)
クライアントからリクエストが来たら prefill フェーズのみを計算しそれぞれの KV cache を保存してキューに追加します。キューにある KV cache をバッチにまとめて decode フェーズを計算して1トークンだけ生成して KV cache を更新しキューに戻します。生成が終了した KV cache はキューに戻さずにクライアントに生成結果を送信します。これを繰り返してテキストを生成します。

1トークン生成するごとにバッチを作り直して新しい KV cache を詰めるのでバッチ内に既に生成が終わったリクエストが残ることがなくなりDynamic batchingの欠点を解決しています。特に生成されるテキストの長さが大きく異なるときにStatic batchingやDynamic batchingと比べてスループットの向上が見込めます。冒頭のanyscaleのブログ記事にはContinuous batchingはStatic batchingと比べてmaximum number of generated tokensが長いほどスループットがベースラインから改善している結果が掲載されています。
実装上は KV cache のマネジメントや1トークンの生成ごとのスケジューリングなどが必要なため、テキスト生成に関する多くの部分がそれぞれのライブラリでContinuous batchingをサポートするために再実装されています。ユーザー目線ではtransformersライブラリにあるようなテキスト生成の際のオプション[1]が一部実装されていなかったり、特定のモデルに対応していなかったりすることに注意が必要です。
Text Generation InferenceやvLLMはContinuous batchingをサポートしています。前述のようにサーバとモデルの推論が密結合しているのでバックエンドを選ぶという概念はありませんが、Text Generation Inferenceはbitsandbytesによるint8での推論やflash attentionなどに対応しているようです。またvLLMはPagedAttention[2]と呼ばれる手法を提案していて、バッチにまとめるときのKV cacheのコピーを不要にしたり、 prefill にはxformersのmemory-efficient attentionを利用したりすることでメモリ消費量を抑えてより大きなバッチサイズを実現しているようです。
まとめ
テキスト生成モデルをホストするAPIサーバのスループットを向上するために、サーバサイドでリクエストをバッチ化するアルゴリズムを紹介しました。Dynamic batchingは入出力が固定長な画像認識のようなモデルに対しては必要十分でしたが、入出力が可変長なテキスト生成では効率を高めにくいものでした。Continuous batchingはcausal language modelingの性質を活かしてテキスト生成に特化していて研究の進歩がうかがえますね。ユースケースに合わせてbatching algorithmsを選択することが重要になりそうです。次回以降ではrinnaのモデルを例にどのようなユースケースでどのbatching algorithmsが有用か確かめていきたいと思います。
前回← DeepSpeed, vLLM, CTranslate2 で rinna 3.6b の生成速度を比較する
次回→ バッチ化対応APIサーバで rinna 3.6b のスループットを実測する
Discussion