vLLM V1の実装②:EngineCore
前回 (だいぶ前😇) は vLLM V1 の概要と EngineCoreClient について見てきましたが、今回は実際の推論処理を担う EngineCore
の実装を見ていきたいと思います。
vLLM V1 における推論処理の中核を担うのが EngineCore
およびその派生クラス群です。前回の EngineCoreClient の解説ではクライアントサイドのアーキテクチャを見てきましたが、今回は実際の推論オーケストレーションを行うサーバーサイドの実装を詳しく追っていきます。
(参照したコミットは時期がズレているので、最新と比較すると細部または実装方針が異なるかもしれません。ご了承ください)
EngineCore の位置づけとアーキテクチャ概要
vLLM V1 では、推論処理を効率化するために明確な責任分離を実現しています。EngineCore
は「Inner loop of vLLM's Engine」として機能し、実際の推論タスクをオーケストレーションする司令塔の役割を担います。
EngineCore 系クラスは次の3層構造で設計されています。
クラス | 役割 | 特徴 |
---|---|---|
EngineCore |
推論オーケストレーション | Scheduler、ModelExecutor との連携による推論実行 |
EngineCoreProc |
プロセス分離と ZMQ 通信 | バックグラウンドプロセスでの安定動作、ZMQ による非同期通信 |
DPEngineCoreProc |
データ並列処理 | 複数エンジン間のウェーブ同期、分散環境での協調処理 |
EngineCore: 推論オーケストレーションの中核
基本構成と初期化
EngineCore
は vLLM V1 における推論処理の心臓部です。初期化時に主要なコンポーネントを設定し、推論実行の基盤を構築します。
class EngineCore:
"""Inner loop of vLLM's Engine."""
def __init__(self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
executor_fail_callback: Optional[Callable] = None):
# プラグインロード
from vllm.plugins import load_general_plugins
load_general_plugins()
# モデル実行エンジンのセットアップ
self.model_executor = executor_class(vllm_config)
# KVキャッシュの初期化とプロファイリング
num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
self._initialize_kv_caches(vllm_config)
# 構造化出力管理
self.structured_output_manager = StructuredOutputManager(vllm_config)
# スケジューラのセットアップ
self.scheduler: SchedulerInterface = Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
structured_output_manager=self.structured_output_manager,
# ... 他の設定
)
# マルチモーダル入力キャッシュサーバ
self.mm_input_cache_server = MultiModalInputCacheServer(
vllm_config.model_config, MULTIMODAL_REGISTRY)
初期化処理では、まず load_general_plugins()
により vLLM のプラグインシステムを起動し、拡張機能を有効化します。次に executor_class
を使用してモデル実行エンジンをセットアップしますが、これは TPU や GPU などのハードウェアに応じた実行戦略を抽象化したものです。
_initialize_kv_caches
メソッドでは、利用可能な GPU メモリを正確に測定し、モデルのアテンション計算で使用される KV キャッシュのブロック数を決定します。このプロファイリング処理により、GPU メモリと CPU メモリの最適な配分を実現し、メモリ効率を最大化しています。
StructuredOutputManager
は JSON Schema やガイド付き生成など、構造化された出力を生成するための管理コンポーネントです。そして Scheduler
は、これらの設定を基に推論リクエストのバッチング戦略や実行順序を制御する中核コンポーネントとして初期化されます。最後に、画像や音声などのマルチモーダル入力を効率的に管理するための MultiModalInputCacheServer
を起動し、全体の初期化が完了します。
推論処理の基本フロー
EngineCore
の推論処理は、次のような step メソッドとして分離されています。
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
"""Schedule, execute, and make output.
Returns tuple of outputs and a flag indicating whether the model
was executed.
"""
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
model_output = self.execute_model_with_error_logging(
self.model_executor.execute_model, # type: ignore
scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output) # type: ignore
return (engine_core_outputs,
scheduler_output.total_num_scheduled_tokens > 0)
スケジューラによるスケジューリング、それに基づくモデル実行、そして結果の更新という一連の流れが明確に分離されており、これらの処理によって根幹たる LLM の推論処理が実行されます。
これらの処理はそれぞれ Scheduler、ModelExecutor といったコンポーネントに委譲されており、それぞれの責任が明確に分離しつつ疎結合とすることで、柔軟な拡張性と保守性を実現しようとしていると感じます。
バッチキューによるパイプライン並列対応
パイプライン並列処理では、step_with_batch_queue
メソッドを使用してより高度な実行制御を行います。
def step_with_batch_queue(
self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
"""Schedule and execute batches with the batch queue.
Note that if nothing to output in this step, None is returned."""
assert self.batch_queue is not None
engine_core_outputs = None
scheduler_output = None
# Try to schedule a new batch if the batch queue is not full
if not self.batch_queue.full():
scheduler_output = self.scheduler.schedule()
if scheduler_output.total_num_scheduled_tokens > 0:
future = self.model_executor.execute_model(scheduler_output)
self.batch_queue.put_nowait(
(future, scheduler_output)) # type: ignore
scheduled_batch = (scheduler_output is not None
and scheduler_output.total_num_scheduled_tokens > 0)
# If no more requests can be scheduled and the job queue is not empty,
# block until the first batch in the job queue is finished.
if not scheduled_batch and not self.batch_queue.empty():
future, scheduler_output = self.batch_queue.get_nowait()
# Blocking until the first result is available.
model_output = self.execute_model_with_error_logging(
lambda _: future.result(), scheduler_output)
self.batch_queue.task_done()
engine_core_outputs = (self.scheduler.update_from_output(
scheduler_output, model_output))
return engine_core_outputs, scheduled_batch
このメソッドは、パイプライン並列処理で複数のバッチを非同期に処理する仕組みを実装しています。まず、バッチキューに空きがある場合、スケジューラから新しいバッチをスケジューリングし、処理すべきトークンがあれば model_executor.execute_model()
を非同期で実行します。この実行結果は future
オブジェクトとして返され、スケジューラ出力とペアでバッチキューに格納されます。
重要なのは、新しいバッチがスケジューリングできなかった場合の処理です。キューが空でなければ、最初の完了待ちバッチを取得し、future.result()
でブロッキングして結果を待ちます。これにより、GPU が推論処理を実行している間も CPU 側は次のバッチの準備を進めることができ、GPU の計算リソースを最大限活用しつつ、パイプライン並列処理における「バブル」(GPU がアイドル状態になる空き時間)を最小化しています。
cf. パイプライン並列
EngineCoreProc: プロセス分離と ZMQ 通信
アーキテクチャ設計の意図
EngineCoreProc
は EngineCore
を継承し、コメントの通り「ZMQ-wrapper for running EngineCore in background process」として機能します。この設計により、推論処理を独立したプロセスで実行し、システム全体の安定性と拡張性を向上させています。
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
def __init__(self, ...):
# キューベースの入出力管理
self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], bytes]]()
# エンジンインデックスと識別子の設定
self.engine_index = engine_index
identity = self.engine_index.to_bytes(length=2, byteorder="little")
# ハンドシェイクとアドレス設定
with self._perform_handshakes(...) as addresses:
# バックグラウンドI/Oスレッドの起動
input_thread = threading.Thread(target=self.process_input_sockets, ...)
self.output_thread = threading.Thread(target=self.process_output_sockets, ...)
初期化処理では、まず入力と出力のための Python キューを作成します。input_queue
は EngineCoreClient からの推論リクエストを受け取るため、output_queue
は処理結果を返すために使用されます。これらのキューは、ZMQ ソケットとメインループ間の非同期通信を実現するバッファとして機能します。
次に、複数のエンジンインスタンスを識別するための engine_index
を設定し、これを2バイトのバイナリ形式(little endian)に変換して ZMQ の識別子として使用します。この識別子により、分散環境で複数のエンジンプロセスが動作する際に、各プロセスを一意に識別できます。
最後に _perform_handshakes
によってフロントエンドプロセスとの通信が確立されます。このハンドシェイク処理により ZMQ のアドレス情報が交換され、その後2つの専用スレッドが起動されます。process_input_sockets
スレッドは ZMQ ソケットから入力を受信し、process_output_sockets
スレッドは処理結果を送信します。この設計により、I/O 操作と推論処理を分離し、それぞれが独立して効率的に動作できるようになっています。
ハンドシェイクプロトコル
EngineCoreProc
の起動時には、フロントエンドプロセスとの間で次のような詳細なハンドシェイクを実行します。
@staticmethod
def startup_handshake(
handshake_socket: zmq.Socket,
local_client: bool,
headless: bool,
parallel_config: Optional[ParallelConfig] = None,
) -> EngineZmqAddresses:
# Send registration message.
handshake_socket.send(
msgspec.msgpack.encode({
"status": "HELLO",
"local": local_client,
"headless": headless,
}))
# Receive initialization message.
logger.info("Waiting for init message from front-end.")
if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
raise RuntimeError("Did not receive response from front-end "
f"process within {HANDSHAKE_TIMEOUT_MINS} "
f"minutes")
init_bytes = handshake_socket.recv()
init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
init_bytes, type=EngineHandshakeMetadata)
logger.debug("Received init message: %s", init_message)
if parallel_config is not None:
for key, value in init_message.parallel_config.items():
setattr(parallel_config, key, value)
return init_message.addresses
ハンドシェイクプロトコルでは、まずエンジンプロセスがフロントエンドに対して "HELLO" ステータスを含む登録メッセージを送信します。このメッセージには local_client
(ローカル接続)と headless
(GUI を持たない実行環境)の情報が含まれ、フロントエンドが設定を決定するために使用されます。
その後、エンジンプロセスは指定されたタイムアウト時間(デフォルトで数分)内にフロントエンドからの初期化メッセージを待機します。この待機には zmq.Socket.poll()
を使用し、ノンブロッキングでメッセージの到着をチェックします。タイムアウトが発生した場合は RuntimeError
を発生させ、システム管理者に通信障害を通知します。
フロントエンドからの初期化メッセージには、ZMQ アドレス情報と並列処理設定が含まれています。並列設定が提供された場合、エンジンプロセスは自身の parallel_config
を動的に更新し、分散環境での動作を確保します。最終的に返される EngineZmqAddresses
には、実際の通信で使用する ZMQ エンドポイントのアドレス情報が含まれており、これにより柔軟な分散構成と動的なネットワーク設定が実現されます。
バックグラウンドI/Oとメインループ
EngineCoreProc
の中核は、バックグラウンド I/O スレッドとメインループの連携です。
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core and return the outputs.
self._process_engine_step()
def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""
waited = False
while not self.engines_running and not self.scheduler.has_requests():
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.")
waited = True
req = self.input_queue.get()
self._handle_client_request(*req)
if waited:
logger.debug("EngineCore loop active.")
# Handle any more client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
run_busy_loop
は、エンジンプロセスの心臓部となる無限ループを実装しています。このループは SIGINT または SIGTERM シグナルを受信するまで継続し、2つの主要な処理を順次実行します。まず _process_input_queue()
で入力キューを処理し、推論すべきリクエストがある状態にしてから、_process_engine_step()
で実際の推論処理を実行します。
_process_input_queue
の実装は、エンジンが動作中でなくスケジューラにもリクエストがない場合、入力キューからのリクエストを待機します。self.input_queue.get()
はブロッキング操作のため、新しいリクエストが到着するまでスレッドを休止状態にし、CPU リソースを無駄に消費しません。デバッグログでは待機状態を記録し、システム管理者が動作状況を把握できるようにしています。
重要な最適化として、待機が終了した後は get_nowait()
を使用してキューに蓄積された追加リクエストを一度に処理します。これにより、複数のリクエストが同時に到着した場合でも効率的にバッチ処理でき、全体のスループットを向上させています。
また、_handle_client_request
メソッドでは、受信したリクエストのタイプに応じて処理を実行します。推論リクエスト、統計情報要求、プロファイリング制御、アダプタ管理など、多様なリクエストタイプを効率的に振り分け、対応する内部メソッドを呼び出します。この設計により、ZMQ 通信による I/O 処理と CPU/GPU 集約的な推論処理を効率的に分離し、それぞれが最適な状態で動作できるようになっています。
DPEngineCoreProc: データ並列処理の実現
データ並列アーキテクチャの特徴
DPEngineCoreProc
は EngineCoreProc
を拡張し、「ZMQ-wrapper for running EngineCore in background process in a data parallel context」として、複数のエンジンインスタンス間での協調処理を実現します。
class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process
in a data parallel context."""
def __init__(self, ...):
# ウェーブ管理用のカウンタ
self.step_counter = 0
self.current_wave = 0
self.last_counts = (0, 0)
# データ並列ランクの設定
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(..., engine_index=dp_rank)
def _init_data_parallel(self, vllm_config: VllmConfig):
# データ並列グループの初期化
self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
初期化処理では、データ並列処理特有の管理機構を設定します。step_counter
は現在のウェーブ内でのステップ数を追跡し、定期的な同期タイミングを決定するために使用されます。current_wave
は現在処理中のウェーブ番号を表し、全エンジンが協調して進行状況を管理します。last_counts
は前回の統計情報を保持し、パフォーマンス監視と最適化の判断材料として活用されます。
データ並列ランクの設定では、vllm_config.parallel_config.data_parallel_rank
からこのエンジンインスタンスの一意な識別番号を取得し、これを engine_index
として親クラスに渡します。これにより、複数のエンジンプロセスが並行実行される際に、各プロセスを明確に区別できます。
_init_data_parallel
では、データ並列グループの初期化を行います。stateless_init_dp_group()
は、分散環境での通信グループを確立し、all-reduce 操作などの集合通信を可能にします。このグループ通信により、各エンジンは独立してリクエストを処理しながらも、必要な同期ポイントで全体の状態を調整し、一貫性を保つことができます。
ウェーブ同期メカニズム
データ並列処理の核心は「ウェーブ」と呼ばれる同期単位です。各エンジンは独立してリクエストを処理しますが、一定のタイミングで全エンジンの状態を同期します。
def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case."""
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core.
executed = self._process_engine_step()
self._maybe_publish_request_counts()
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
if not executed:
if not local_unfinished_reqs and not self.engines_running:
# All engines are idle.
continue
# We are in a running state and so must execute a dummy pass
# if the model didn't execute any ready requests.
self.execute_dummy_batch()
# 3) All-reduce operation to determine global unfinished reqs.
self.engines_running = self._has_global_unfinished_reqs(
local_unfinished_reqs)
if not self.engines_running:
if self.dp_rank == 0 or not self.has_coordinator:
# Notify client that we are pausing the loop.
logger.debug("Wave %d finished, pausing engine loop.",
self.current_wave)
# In the coordinator case, dp rank 0 sends updates to the
# coordinator. Otherwise (offline spmd case), each rank
# sends the update to its colocated front-end process.
client_index = -1 if self.has_coordinator else 0
self.output_queue.put_nowait(
(client_index,
EngineCoreOutputs(wave_complete=self.current_wave)))
# Increment wave count and reset step counter.
self.current_wave += 1
self.step_counter = 0
データ並列版の run_busy_loop
では、基本版と比較してより精密な同期制御が実装されています。各ステップで入力キューの処理と推論実行を行った後、_maybe_publish_request_counts()
でリクエスト数の統計情報を他のエンジンと共有します。これにより、負荷分散の状況を監視し調整を行えます。
また「ダミーバッチ」の実行によりエンジンの同期を維持します。実際のリクエストを処理しなかった場合でも、エンジンが実行中状態にあるときは execute_dummy_batch()
を実行します。これは、データ並列処理で全エンジンの同期を維持するため、実際の計算がなくても同期ポイントを確保する仕組みです。この設計により、一部のエンジンだけが遅れることを防ぎ、全体の処理効率を向上させています。
さらに _has_global_unfinished_reqs()
の結果によって all-reduce を行います。これは各エンジンのローカルな未完了リクエスト状態を集約し、グローバルな実行継続判定をします。全エンジンで未完了リクエストがなくなった場合、ウェーブが完了したとして次のウェーブに移行します。ランク0のエンジン(またはコーディネータを使用しない場合は各エンジン)がクライアントに完了通知を送信し、current_wave
をインクリメントして新しいウェーブを開始します。
効率的な同期最適化
データ並列処理では、頻繁な同期がパフォーマンスのボトルネックになり得ます。DPEngineCoreProc
では、次のような最適化を実装しています。
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
self.step_counter += 1
if self.step_counter % 32 != 0:
return True
return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)
この同期最適化は、データ並列処理の性能向上において重要な役割を果たします。毎ステップで all-reduce 操作をすると、ネットワーク通信のレイテンシが推論処理全体のボトルネックになってしまいます。そこで、step_counter
を32でモジュロ演算し、32ステップに1回だけ実際の同期処理を実行する仕組みを採用しています。
31回は単純に True
を返すことで、エンジンの実行を継続させます。これにより、各エンジンは他のエンジンの状態を待つことなく、並列に推論処理を進められます。32回目のステップでは ParallelConfig.has_unfinished_dp()
を呼び出し、all-reduce 操作を通じて全エンジンの実際の状態を確認します。
この32ステップのバッファリングは、ネットワーク帯域幅と同期タイミングのバランスを見てのものと思われます。バッファが小さすぎると通信オーバーヘッドが大きくなり、大きすぎると各エンジン間の負荷不均衡が長時間放置されるので、ネットワーク通信のオーバーヘッドを効率的に削減しつつ、必要なタイミングでの同期を維持するためでしょう。
実装の特徴/考察
プロセス分離による安定性の向上
vLLM V1 の EngineCore アーキテクチャは、プロセス分離により次の利点を実現しています。
- フォルトトレラント性: エンジンプロセスの異常終了がクライアントプロセスに波及しない
- リソース分離: GPU リソースと CPU リソースの独立管理を実現
- スケーラビリティ: 複数エンジンインスタンスの動的管理が可能
ZMQ 通信による非同期処理
ZMQ(ZeroMQ)の採用により、次のような高性能な非同期通信を実現しています。
def process_input_sockets(self, input_addresses: list[str], ...):
"""Input socket IO thread."""
with ExitStack() as stack, zmq.Context() as ctx:
input_sockets = [
stack.enter_context(
make_zmq_socket(ctx, input_address, zmq.DEALER, ...))
for input_address in input_addresses
]
# ポーラーによる効率的なソケット監視
poller = zmq.Poller()
for input_socket in input_sockets:
poller.register(input_socket, zmq.POLLIN)
while True:
for input_socket, _ in poller.poll():
# メッセージの非同期受信と処理
type_frame, *data_frames = input_socket.recv_multipart(copy=False)
# ... リクエスト処理
この ZMQ 実装は、まず複数の入力アドレスに対してそれぞれ DEALER ソケットを作成し、ExitStack
を使用してリソースの確実な解放を保証します。各ソケットは独立したクライアントからの接続を受け付け、負荷分散と冗長性を提供します。
zmq.Poller
の使用により、複数のソケットを効率的に監視できます。 select()
や poll()
と比較して、メッセージキューの状態も考慮するため、より正確なイベント検出が可能なようです。ポーリングによってメッセージが到着したソケットを特定し、そのソケットからのみ読み取りを行うため、不要なブロッキングを回避します。
メッセージ受信では recv_multipart(copy=False)
を使用し、ゼロコピー操作でパフォーマンスを最大化しています。メッセージは複数のフレームで構成され、最初のフレームがリクエストタイプ、残りのフレームがデータとなります。この設計により、大きなデータも効率的に転送でき、I/O 待機による CPU/GPU リソースの無駄を最小化し、高いスループットを実現しています。専用の I/O スレッドで通信を処理することで、メインの推論ループが通信待機でブロックされることなく、継続的に推論処理を実行できます。
商用構成への配慮
EngineCore アーキテクチャの設計には、明らかに商用環境での運用を意識した次のような配慮が見受けられます。
統計情報の詳細収集とモニタリング
vLLM V1 では、プロダクション環境での運用監視を前提とした包括的な統計収集機能を提供しています。各エンジンプロセスは、リクエスト処理数、処理時間、GPU/CPU 使用率、メモリ消費量などの詳細なメトリクスを継続的に収集し、これらの情報を外部の監視システムに送信できます。
特に _maybe_publish_request_counts()
メソッドでは、データ並列環境でのリクエスト分散状況を定期的に報告し、負荷バランシングの調整判断に必要な情報を提供します。これにより、システム管理者は各エンジンの稼働状況を可視化し、ボトルネックの早期発見や予防保守を実現できます。
動的構成変更とスケーラビリティ
商用環境では、トラフィックの変動に応じてシステムリソースを動的に調整します。vLLM V1 のアーキテクチャは、実行時でのデータ並列サイズの変更、エンジンインスタンスの追加・削除、メモリ割り当ての再配分などを、サービス停止なしで実行できるよう設計されています。
ZMQ による疎結合な通信設計により、新しいエンジンプロセスは既存のシステムに動的に参加でき、古いプロセスは処理中のリクエストを完了してから安全に退場できます。このホットスワップ機能は、24時間365日稼働するプロダクション環境では不可欠な要件です。
ロードバランシング戦略の多様性
vLLM V1 は、外部ロードバランサ(ALB、nginx など)、内部ラウンドロビン、ハイブリッド構成など、多様なロードバランシング戦略をサポートしています。EngineCoreProc
の設計では、複数の入力ソケットを並行監視することで、異なる負荷分散アルゴリズムからのリクエストを効率的に処理できます。
また、ウェーブ同期メカニズムにより、データ並列環境でも一貫したレスポンス時間を維持し、特定のエンジンに負荷が集中することを防ぎます。これらの機能は、単なる研究用ツールではなく、本格的なプロダクション環境での利用を前提とした設計思想を示していると言えるでしょう。
v0からの進化ポイント
vLLM V1 の EngineCore アーキテクチャは、v0 と比較して次の点で進化しています。
責任分離の明確化
v0 の LLMEngine
は、巨大なモノリシッククラスでした。多くのメソッドと多数のインスタンス変数を持ち、推論処理、スケジューリング、I/O 管理、統計収集、LoRA アダプタ管理、トレーシングなど、あらゆる責任を1つのクラスに集約していました。
class LLMEngine:
# 1800行以上の巨大クラス
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
# 200行以上の複雑なメソッド
# スケジューリング、実行、出力処理を全て担当
if not self._has_remaining_steps(seq_group_metadata_list):
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc) = self.scheduler[virtual_engine].schedule()
if not scheduler_outputs.is_empty():
outputs = self.model_executor.execute_model(execute_model_req)
if not allow_async_output_proc:
self._process_model_outputs(ctx=ctx)
v1 では、これを明確に分離しています。
-
EngineCore
: 純粋な推論オーケストレーション -
EngineCoreProc
: プロセス管理と ZMQ 通信 -
DPEngineCoreProc
: データ並列処理の協調制御
各クラスは単一責任原則に従い、保守性と拡張性を向上させています。
非同期処理の根本的な再設計
v0 の処理は基本的に同期型で、step()
メソッド内で全ての処理を順次実行していました。パイプライン並列処理も「将来機能」として位置づけられ、実際には NotImplementedError
を発生させる状態でした。
def step(self):
if self.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise.")
v1 では、設計レベルから非同期処理を前提とした作りに変更されています。
- ZMQ による真の非同期通信: 専用の I/O スレッドによる通信とメインループの分離
- バッチキューによるパイプライン並列:
step_with_batch_queue
での GPU リソースの最大活用 - バックグラウンド処理:
process_input_sockets
とprocess_output_sockets
スレッドによる並行処理
この変更により、v0 では実現困難だった高度な並列処理パターンが標準機能として利用可能になりました。
データ並列処理の対応
v0 では、データ並列処理は後付け的な実装で、基本的には複数の LLMEngine
インスタンスを外部で管理する形でした。エンジン間の協調も限定的で、同期メカニズムも原始的でした。
v1 の DPEngineCoreProc
では、データ並列処理を最初から対象として扱い機能を提供しています。
おわりに
vLLM V1 における EngineCore アーキテクチャの実装を詳しく見てきました。推論処理の核となる EngineCore
、プロセス分離と ZMQ 通信を担う EngineCoreProc
、そしてデータ並列処理を実現する DPEngineCoreProc
の3層構造により、性能と安定性を両立した設計になっていることがわかります。
特に印象的なのは、単なる機能追加ではなく、アーキテクチャレベルでの根本的な再設計を実現している点です。プロセス分離、非同期通信、ウェーブ同期といった高度な仕組みを組み合わせることで、スケーラブルで堅牢な推論システムを構築しています。
前回の記事でも書いたとおり、vLLM V1は商用構成を見越した設計になっている印象を受けます。ただ動かすだけでなく、スケールアウトや高スループットを意識した設計をしたいという意図が感じられ、やはり vLLM はただの推論ライブラリではなく、サービスのバックエンドとして安定して動作することを目指している...ような気がします。
V0 から V1 への移行は単なるバージョンアップ以上の意味を持ち、LLM 推論基盤としての成熟度を大きく向上させたと言えるのではないでしょうか。
Discussion