POD-Attentionの論文を読んで
まだ書きかけですが..
原文
Microsoft Researchとワシントン大学の方が投稿
POD-Attention は注意機構の計算を最大59%(平均28%)高速化し、独立して最適化されたプレフィルおよびデコードのアテンションカーネルを使用する場合と比べて、より高いスループット(処理量)と低いレイテンシ(遅延)で大規模言語モデル(LLM)の推論を可能にします。
事前知識
vllm
sarathi
原文を読んで、まとめた(要約)した記事です。
LLMの推論のフェーズ
LLM(大規模言語モデル)の推論におけるは、次の2つのフェーズを踏む。
- 計算リソースに依存する「プリフィル(prefill)」フェーズ
- メモリ帯域幅に依存する「デコード(decode)」フェーズ
プリフィルのイテレーションの遅延 > デコードのイテレーションの遅延
「プリフィル(prefill)」フェーズ
ユーザーのプロンプトのトークンを並列に処理し、最初の出力トークンを生成。
- 非常に並列化されており、計算リソースに依存。
「デコード(decode)」フェーズ
「プリフィル(prefill)」フェーズの後、各イテレーションで1つの出力トークン(リクエストごと)を自己回帰的に生成します。各出力トークンを生成するのにかかる遅延はTBTと呼ばれる。
- メモリ帯域幅に依存
- デコード(decode)は1回のリクエストにつき1つのトークンを処理するため、クエリ系列長(QSL)次元におけるタイル長は1となる。(Group Query Attentionでは、このタイル長がクエリヘッド数とKVヘッド数の比率まで増加し、通常は2~8の範囲。)
アテンション周りのカーネルAPI
FlashAttention (FA) や FlashInfer (FI) などの最先端のライブラリは、各フェーズごとに最適化された専門的なカーネルAPIを提供
FlashDecoding は、QSL が 1 のデコード処理向けに設計されており、GPU の SM(Streaming Multiprocessors)を完全に使い切るための並列性が不十分な場合には、K/V(キー/バリュー)次元に計算をさらに分割する。
GPU効率化の問題が抱える課題
-
1つのリクエストにおいてプリフィルとデコードは、異なるタイミングで発生する。
-
多くの実世界のLLMアプリケーションでは、コンテキスト長が引き続き増加している。
そのようなシナリオでは、アテンション計算が支配的となり、総推論時間の60%以上を占めることが多くなってきた。
最先端のLLM提供システムでの対策
ハイブリッドバッチ処理(hybrid batching)
- 異なるリクエストのプリフィルフェーズとデコードフェーズの入力を同じバッチにまとめて処理。
- GPUが一度だけモデルの重みを読み込み、それを使ってプリフィル入力とデコード入力の両方に対する計算を行えるようにする。
- スケジューラは長い入力プロンプト(プリフィル入力)を複数の小さなチャンクに分割し、各イテレーションで進行中のデコードと新しいプリフィルチャンクを組み合わせて処理
図 ハイブリッドバッチの計算(引用:https://arxiv.org/pdf/2410.18038)
スケジュールの戦略
- プリフィルとデコード操作の異なる計算特性が、LLM推論におけるスループットとレイテンシのトレードオフを生み出す。
図 スケジュールの戦略によるトレードオフ(引用:https://arxiv.org/pdf/2410.18038)
vLLMスケジューラ は、プリフィル優先のスケジューリングを使用してデコードバッチサイズを最大化。
- 低いTTFT
- 高いTBT
Sarathi-Serveは、チャンク化されたプリフィルと連続的なハイブリッドバッチ処理を提案。この技術は、リクエストのプリフィルトークンを複数の小さなチャンクに分け、進行中のデコードとともに1イテレーションごとに1つのプリフィルチャンクをスケジューリングする。
- 高いTTFT
- 低いTBT
論文が提案するPOD-Attention
プリフィルとデコードのアテンションを同時に計算するGPUカーネル(GPU 上で実行される関数)であり、計算資源とメモリ帯域幅の両方を同時に活用できる。
なぜ既存の技術がアテンション計算の融合において十分なパフォーマンスを発揮しないのか?
遅延したスレッド、同期バリア、およびGPUストリーミングマルチプロセッサ(SM)上で異なる協調スレッド配列(CTA)のSMレベルでの共存保証の欠如がある。
(引用:https://arxiv.org/pdf/2410.18038)
GPUの機構周り
最小の実行単位はスレッド。
32 個のスレッドが集まってワープ(warp)を構成。
ワープ内のスレッドは通常、同時にロックステップ(同期して同じ命令を実行。)
GPU カーネル(GPU 上で実行される関数)を起動する場合、
- 各 CTA に含めるスレッド数、
- カーネル内の CTA の数、
- 各 CTA に必要な共有メモリ量
を指定
カーネルの起動は「ストリーム」にキューされ、ストリーム内の処理は直列(シリアル)に実行される。
異なるストリーム間では順序に関係なく並列実行される可能性がある。
階層レベルの並列実行の方法
- Kernel-parallel
- ストリームは、異なる GPU カーネルを並列に実行できる
- 異なる処理が同じ SM上に配置されることは保証されない。
- CTA-parallel
- カーネル内の CTA(協調スレッド配列)があらかじめ決められた方法で各処理に分割される。
- 異なる処理が同じ SM上に配置されることは保証されない。
- Warp-parallel
- CTA 内のすべてのワープが同じ SM(Streaming Multiprocessor)上に存在することが保証される。
- ストラグラー問題が存在する。
- CTA 全体の実行が完了するまで次の CTA を割り当てることができない。
- Intra-thread
- 各スレッドが異なる処理の命令を交互に実行
POD-Attentionによる対策
POD-Attention は、CTA単位で並列に計算を融合することでこれらの問題に対処し、GPU内部でSMを意識したソフトウェアベースのCTAスケジューリングを導入した。
LLM推論スケジューラである Sarathi-Serve に関する実験結果では、POD-Attention が、FlashAttention と FlashInfer のプリフィルおよびデコードアテンションカーネルよりも最大59%速く(平均28%速く)アテンションを計算することが示された。
POD-Attentionの詳細
プレフィルとデコードのアテンションを効率的に計算する単一の GPU カーネル。
SM-aware CTA スケジューリング
プレフィルとデコードの CTA を共配置する。ここでは、CTA がランタイムでプレフィルまたはデコードを実行するかを決定する際に、次の二つを確認。
- どの SM で起動されたか
- 同じ SM 上で実行中の他の CTA が何をしているか。
プレフィルとデコードに必要な CTA の数を独立して決定し、その合計に一致する CTA 数でカーネルを起動する。
各 SM には、起動された CTA の数を追跡するカウンターに加え、これまでに実行されたプレフィルとデコード CTA の数を追跡する2つのカウンターも備えている。
(引用:https://arxiv.org/pdf/2410.18038)
- 50:50 ポリシー
- SM 上での次の CTA はプレフィルとデコードが交互にスケジュールされる。
- 比例ポリシー
- 現在のバッチ内でのプレフィルとデコードの CTA 比率に基づいて CTA を割り当てる。
パフォーマンス最適化
単にプレフィル(prefill)とデコード(decode)操作を同じ場所で実行するわけでない、
プレフィルとデコードのアテンション計算を融合させる効果を最大化するためのさまざまな最適化手法
がある。
-
タイルサイズ(Tile Sizes)
- QSL(クエリ系列長)に対しては、CUTLASS が A100 のテンソル演算で要求する最小のデコードタイル長16を使用した。
- デコードに大きなタイルサイズを使うことは、融合カーネル(fused kernel)では逆効果。
デコードで発生する冗長な計算が、同じ場所で実行されるプレフィル処理に干渉するから。これは両者がテンソルコアを共有しているため。
-
SMあたりの並行CTA
-
仮想デコードCTA
-
プレフィル分割の制限
Discussion