バッチ化対応APIサーバで rinna 3.6b のスループットを実測する
はじめに
前回の記事ではテキスト生成APIサーバのスループットを高めるbatching algorithmsについて紹介しました。今回は実際にAPIサーバに対して負荷テストを実施することで処理能力を実測します。dynamic batchingが可能なFasterTransformer+Triton Inference Serverとcontinuous batchingが可能なvLLMを比較します。モデルはHugging Faceで公開されているrinna/japanese-gpt-neox-3.6b-instruction-ppo
を利用します。
APIサーバのセットアップ
FasterTransformer+Triton Inference Server
Triton Inference Serverがプリインストールされているdocker imageを利用します。dockerをインストール済みのLinuxマシン上で手順を確認しました。
dockerコンテナ内からGPUにアクセスするためにはNVIDIA Container Toolkitが必要なのでインストールしておきましょう。
今回使いたいFasterTransformerは上記のimageには含まれていないので、こちらのリポジトリあるDockerfileを利用してFasterTransformerをインストールしたimageをビルドします。
v1.3をcloneしてREADMEの通りにdocker imageをビルドします
git clone --branch v1.3 https://github.com/triton-inference-server/fastertransformer_backend.git
cd fastertransformer_backend
export WORKSPACE=$(pwd)
export CONTAINER_VERSION=22.07
export TRITON_DOCKER_IMAGE=triton_with_ft:${CONTAINER_VERSION}
docker build --rm \
--build-arg TRITON_VERSION=${CONTAINER_VERSION} \
-t ${TRITON_DOCKER_IMAGE} \
-f docker/Dockerfile \
.
環境によってはビルドに1時間弱ほどかかります。
docker images
triton_with_ft:22.07
という名前になりました。
REPOSITORY TAG IMAGE ID CREATED SIZE
triton_with_ft 22.07 4bb65bac4b62 About a minute ago 26.9GB
rinna/japanese-gpt-neox-3.6b-instruction-ppo
はGPT-NeoXアーキテクチャなので対応するドキュメント[1]を読みながら進めていきます。
ここからはコンテナを起動してコンテナ上で作業します。
docker run -it --rm --gpus=all --shm-size=1g --ulimit memlock=-1 -v ${WORKSPACE}:${WORKSPACE} -w ${WORKSPACE} ${TRITON_DOCKER_IMAGE} bash
Hugging Faceからrinna/japanese-gpt-neox-3.6b-instruction-ppo
をダウンロードします。
python3 -c 'from huggingface_hub import snapshot_download; snapshot_download("rinna/japanese-gpt-neox-3.6b-instruction-ppo", local_dir="/tmp/rinna-3.6b-ppo")'
モデルの変換スクリプトを取得して実行します。
git clone https://github.com/NVIDIA/FasterTransformer.git
python3 FasterTransformer/examples/pytorch/gptneox/utils/huggingface_gptneox_convert.py \
-o all_models/gptneox/fastertransformer/1 \
-i /tmp/rinna-3.6b-ppo \
-i_g 1 \
-m_n gptneox \
-weight_data_type fp16
完了すると all_models/gptneox/fastertransformer/1/1-gpu
ディレクトリが作成されているはずです。
-
all_models/gptneox/fastertransformer/1/1-gpu/config.ini
- モデルの設定ファイルです。モデルの変換時に生成されるので特に変更する必要はありません。
-
all_models/gptneox/fastertransformer/config.pbtxt
- 入出力やオプションを設定するファイルです。今回の環境に合わせて一部書き換えます。
all_models/gptneox/fastertransformer/config.pbtxt
の以下の部分を編集します。
tensor_para_size: "2"
-> "1"
model_checkpoint_path: "/workspace/ft/models/ft/gptneox/"
-> "all_models/gptneox/fastertransformer/1/1-gpu"
ここまででサーバを起動する準備ができました。次のコマンドでバックグラウンドでサーバを起動します。
mpirun -n 1 --allow-run-as-root /opt/tritonserver/bin/tritonserver --model-repository=all_models/gptneox/ &
しばらくすると以下のポートで待ち受けが開始した旨が表示されます。
I0705 08:31:17.831163 591 grpc_server.cc:4819] Started GRPCInferenceService at 0.0.0.0:8001
I0705 08:31:17.831364 591 http_server.cc:3477] Started HTTPService at 0.0.0.0:8000
I0705 08:31:17.872560 591 http_server.cc:184] Started Metrics Service at 0.0.0.0:8002
HTTPService宛てにリクエストする簡単なクライアントをpythonで書きます。入力として input_ids, input_lengths, request_output_len は最低限必要です。その他のパラメータについてはドキュメント[2]を参照してください。
# client.py
import requests
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-ppo", use_fast=False)
input_ids = tokenizer("ユーザー: 温度と湿度の関係は?<NL>システム: ", add_special_tokens=False).input_ids
data = {
"inputs": [
{
"name": "input_ids",
"datatype": "UINT32",
"shape": [1, len(input_ids)],
"data": input_ids,
},
{
"name": "input_lengths",
"datatype": "UINT32",
"shape": [1, 1],
"data": [len(input_ids)],
},
{
"name": "request_output_len",
"datatype": "UINT32",
"shape": [1, 1],
"data": [128],
},
]
}
response = requests.post(
"http://localhost:8000/v2/models/fastertransformer/infer", json=data
)
for output in response.json()["outputs"]:
if output["name"] == "output_ids":
print(tokenizer.decode(output["data"], skip_special_tokens=True))
実行するとこのようなレスポンスが出力されました。
python3 client.py
ユーザー: 温度と湿度の関係は?<NL>システム: 温度と湿度は、ともに空気の温度と湿度を測定し、測定された値を比較するために使用されます。温度と湿度は、空気の温度と湿度を測定し、測定された値を比較するために使用されます。
デフォルトではdynamic batching[3]は有効化されていませんでした。clientをバックグラウンドで複数起動してみると一定間隔でレスポンスが帰ってきています。
for i in {1..10}; do python3 client.py & done
dynamic batchingを有効にするためには all_models/gptneox/fastertransformer/config.pbtxt
を編集してinputの手前にdynamic_batchingとbatch_inputの項目を追加します。またinput_idsにallow_ragged_batch: trueを追加します。
...
dynamic_batching {
max_queue_delay_microseconds: 50000
}
batch_input [
{
kind: BATCH_ITEM_SHAPE
target_name: "input_ids_item_shape"
data_type: TYPE_INT32
source_input: "input_ids"
}
]
input [
{
name: "input_ids"
data_type: TYPE_UINT32
dims: [ -1 ]
allow_ragged_batch: true
},
...
変更したあとにサーバを再起動して先程と同様にclientをバックグラウンドで複数起動してみます。
for i in {1..10}; do python3 client.py & done
複数のレスポンスが同時に返ってくるようになりました。
vLLM
PyPIからインストールします。
pip install vllm==0.1.2 "pydantic==1.*"
rinna/japanese-gpt-neox-3.6b-instruction-ppo
ではtokenize.enodeの際にadd_special_tokens=Falseオプションを指定する必要があるので変更します。
sed -i -e 's/encode(prompt)/encode(prompt, add_special_tokens=False)/' $(python3 -c "import vllm; print(vllm.__path__[0])")/engine/llm_engine.py
次のコマンドでサーバを起動できます。今回は試していませんがOpenAI互換のサーバも用意されています[4]。
python3 -m vllm.entrypoints.api_server \
--model rinna/japanese-gpt-neox-3.6b-instruction-ppo \
--tokenizer-mode slow \
--host 0.0.0.0 &
しばらくすると起動した旨が表示されます。
INFO 07-06 07:44:04 llm_engine.py:131] # GPU blocks: 1094, # CPU blocks: 661
INFO: Started server process [1172]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
こちらも簡単なクライアントをpythonで書いてみます。promptにテキスト、max_tokensに最大のトークン数を指定します。ここではtemperatureは0を指定してgreedy decodingにしました。他の設定可能なパラメータについてはSamplingParamsクラス[5]を参照してください。
# client.py
import requests
prompt = "ユーザー: 温度と湿度の関係は?<NL>システム: "
data = {
"prompt": prompt,
"max_tokens": 128,
"temperature": 0,
}
response = requests.post(
"http://localhost:8000/generate", json=data
)
print(response.json()["text"][0])
このようなテキストが出力されました。greedy decodingのためFasterTransformerの結果と一致していました。
ユーザー: 温度と湿度の関係は?<NL>システム: 温度と湿度は、ともに空気の温度と湿度を測定し、測定された値を比較するために使用されます。温度と湿度は、空気の温度と湿度を測定し、測定された値を比較するために使用されます。</s>
vLLMはデフォルトでcontinuous batching[6]を実施するのでclientをバックグラウンドで複数起動してみると複数のレスポンスがほぼ同時に帰ってきます。
for i in {1..10}; do python3 client.py & done
負荷テスト
今回はテスト用のプロンプトとしてJAQKETデータセット[7]を用いました。
プロンプトはrinna/japanese-gpt-neox-3.6b-instruction-ppo
のフォーマット[8]に合わせて"ユーザー: {question}<NL>システム: "
として問題の答えが生成されることを期待しています。プロンプトのトークン数は平均35.2でした。短めのプロンプトなのでバッチ化によりスループットの向上が期待できます。今回は出力されるトークン数のばらつきによる影響を見るため、eosトークン (</s>) が出力された時点で生成を打ち切ります。
負荷テストにはlocustを使用してクライアントのコードを実行しました。ユーザ数を128まで徐々に増やしながらテストします。
APIサーバはクラウド上の 4x vCPU & 1x T4 GPU というスペックのマシン上に立てています。
FasterTransformer+Triton Inference Server
request_output_len: 16
出力するトークン数の上限が16の設定です。ユーザ数が128に近いところでは50rpsほどのスループットで処理されていました。そのときの応答時間は2.5秒程度です。
request_output_len: 128
出力するトークン数の上限が128の設定です。ユーザ数が128に近いところでは12rpsほどのスループットで処理されていました。そのときの応答時間は10秒前後になっています。
vLLM
max_tokens: 16
出力するトークン数の上限が16の設定です。ユーザ数が128に近いところでは50rpsほどのスループットで処理されていました。vLLMの場合応答時間のばらつきが大きいことが特徴で、50th-percentileと95th-percentileの差が大きいです。continuous batchingを採用しているため短いレスポンスが応答時間の平均を引き下げていると考えられます。
max_tokens: 128
出力するトークン数の上限が128の設定です。ユーザ数が128に近いところでは30rpsほどのスループットで処理されていました。そのときの応答時間は50th-percentileでは2秒ほど、95th-percentileでは15秒前後になっています。
まとめ
dynamic batchingに対応するFasterTransformer+Triton Inference Serverとcontinuous batchingに対応するvLLMのAPIサーバに対して負荷テストを実施して処理能力を実測しました。dynamic batchingとcontinuous batchingの差が出にくいと考えられる出力トークン数の上限を16に抑えた設定ではどちらも同等のrpsを達成していました。出力トークン数の上限を128とした設定ではcontinuous batchingの利点が示されていてFasterTransformer+Triton Inference Serverでは12rpsほどだったところ、vLLMでは30rpsほどと2倍以上スループットが高いことが確認できました。
今回はモデルロード時のエラーのため動作確認できませんでしたがText Generation InferenceにもvLLMが提案したPaged Attentionが最近導入[9]されているほか、FasterTransformerも継続的に開発が続けられているので今後も注目していきたいですね。
前回← テキスト生成APIサーバのスループットを高めるbatching algorithms
-
https://github.com/triton-inference-server/fastertransformer_backend/blob/v1.4/docs/gptneox_guide.md ↩︎
-
https://github.com/triton-inference-server/fastertransformer_backend/blob/v1.4/docs/gptneox_guide.md#how-to-set-the-model-configuration ↩︎
-
https://zenn.dev/rinna/articles/7d10e61f694611#dynamic-batching-(request-level-scheduling) ↩︎
-
https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html#openai-compatible-server ↩︎
-
https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py ↩︎
-
https://zenn.dev/rinna/articles/7d10e61f694611#continuous-batching-(iteration-level-scheduling) ↩︎
-
https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-ppo ↩︎
Discussion