🥷
vLLMで高速に大規模言語モデルの生成を行う
概要
vLLMはLLMの生成を高速に行うためのライブラリです。
ここではLlama-2-7b-chat-hfをvLLMで簡単に使う方法を紹介します。
vLLMの環境構築はここを参照してください。より詳細なコードはこちら。
使い方
from vllm import LLM, SamplingParams
model_path = "meta-llama/Llama-2-7b-chat-hf"
temperature = 1.0
top_p = 1.0
max_tokens = 128
quantization = None
llm = LLM(
model=model_path,
quantization=quantization,
dtype="bfloat16",
)
sampling_params = SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens
)
prompts = ["こんにちは", "おはようございます", "おやすみなさい"]
responses = llm.generate(prompts, sampling_params=sampling_params)
for prompt, response in zip(prompts, responses):
print("prompt:", prompt)
print("output:", response.outputs[0].text.strip())
print("logprob:", response.outputs[0].cumulative_logprob)
解説
-
LLM
クラスのmodel
引数には、HuggingFaceのモデル名を指定します。その他モデルの設定に関する引数を指定することができます。ここではモデルをbfloat16
で動作させるためにdtype
引数を指定しています。 -
SamplingParams
クラスの引数には、生成時のサンプリング方法に関するパラメータを指定します。 - 入力はstr型のリストで構成されます。
-
generate
メソッドの戻り値はResponse
クラスのインスタンスです。Response
クラスのoutputs
属性には、生成されたトークン列が格納されています。outputs
属性はOutput
クラスのリストです。Output
クラスのtext
属性には、トークン列を文字列に変換したものが格納されています。Output
クラスのcumulative_logprob
属性には、生成されたトークン列の対数尤度が格納されています。
Discussion