🥷

vLLMで高速に大規模言語モデルの生成を行う

2024/01/10に公開

概要

vLLMはLLMの生成を高速に行うためのライブラリです。
ここではLlama-2-7b-chat-hfをvLLMで簡単に使う方法を紹介します。
vLLMの環境構築はここを参照してください。より詳細なコードはこちら

使い方

from vllm import LLM, SamplingParams


model_path = "meta-llama/Llama-2-7b-chat-hf"
temperature = 1.0
top_p = 1.0
max_tokens = 128
quantization = None
    
llm = LLM(
    model=model_path,
    quantization=quantization,
    dtype="bfloat16",
)

sampling_params = SamplingParams(
    temperature=temperature,
    top_p=top_p,
    max_tokens=max_tokens
)

prompts = ["こんにちは", "おはようございます", "おやすみなさい"]

responses = llm.generate(prompts, sampling_params=sampling_params)

for prompt, response in zip(prompts, responses):
    print("prompt:", prompt)
    print("output:", response.outputs[0].text.strip())
    print("logprob:", response.outputs[0].cumulative_logprob)

解説

  • LLMクラスのmodel引数には、HuggingFaceのモデル名を指定します。その他モデルの設定に関する引数を指定することができます。ここではモデルをbfloat16で動作させるためにdtype引数を指定しています。
  • SamplingParamsクラスの引数には、生成時のサンプリング方法に関するパラメータを指定します。
  • 入力はstr型のリストで構成されます。
  • generateメソッドの戻り値はResponseクラスのインスタンスです。Responseクラスのoutputs属性には、生成されたトークン列が格納されています。outputs属性はOutputクラスのリストです。Outputクラスのtext属性には、トークン列を文字列に変換したものが格納されています。Outputクラスのcumulative_logprob属性には、生成されたトークン列の対数尤度が格納されています。

Discussion