TransformersでLLMの出力をストリーム生成してQOL向上
Transformers の TextStreamer
機能を使ってストリーム生成を行います。また、Gradio の ChatInterface
と組み合わせて快適にチャットするサンプルコードも紹介します。
今回の環境
- Python 3.10.12
- Transformers 4.34.1
- Gradio 3.50.2
今回のコードを試せる Colab ノートブック:
TextStreamer でストリーム出力
TextStreamer
は Transformers ライブラリの機能で、これを使うとモデルの generate()
関数で生成する際にストリームでテキストをプリントしてくれます。
以下は elyza/ELYZA-japanese-Llama-2-7b-fast-instruct
を使ったコード例:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
MODEL_NAME = "elyza/ELYZA-japanese-Llama-2-7b-fast-instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
load_in_4bit=True,
torch_dtype=torch.float16,
device_map={"": 0},
)
model.eval()
# ストリーマーを定義
streamer = TextStreamer(
tokenizer,
skip_prompt=False, # 入力文(ユーザーのプロンプトなど)を出力するかどうか
skip_special_tokens=False, # その他のデコード時のオプションもここで渡す
)
# Llama2形式のプロンプト
prompt = """<s>[INST]<<SYS>>
優秀で聡明なアシスタントAIとしてユーザーの要望に応じなさい。
<</SYS>>
猫は液体ですか? [/INST] """
inputs = tokenizer(
prompt,
return_tensors="pt",
add_special_tokens=False, # 先頭に<s>が付与されるのを回避
)
_ = model.generate(
**inputs,
max_new_tokens=256,
do_sample=True,
top_k=20,
top_p=0.95,
num_beams=1, # 現時点では2以上を設定するとエラー (ビームサーチは使えない)
streamer=streamer, # ストリーマーを渡す
)
これを実行すると、以下がストリームで出力されます。
<s> [INST]<<SYS>>
優秀で聡明なアシスタントAIとしてユーザーの要望に応じなさい。
<</SYS>>
猫は液体ですか? [/INST] 「猫が液体である」という説も、確かに有り得ます。
しかし、「液体の状態を維持するために、常に一定の圧力がかかる必要がある」という条件に鑑みて考えると、現実世界では非常に難しいことでしょう…。</s>
生成が終わってから一度に出力されるわけではないため、待ち時間が軽減されてQOL向上間違いなしです。
Gradio の ChatInterface でストリーム返信
Gradio にはチャット専用のインターフェースである ChatInterface
が用意されています。
これを使うことでかなり簡単にストリーム生成に対応したチャットデモを作成することができます。
ここでは TextStreamer
のかわりに TextIteratorStreamer
を使用して、イテレーターから生成された文字を扱います。
コード例:
from threading import Thread
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import gradio as gr
SYSTEM_PROMPT = "あなたは誠実で優秀な日本人のアシスタントです。"
MODEL_NAME = "elyza/ELYZA-japanese-Llama-2-7b-fast-instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
load_in_4bit=True,
torch_dtype=torch.float16,
device_map={"": 0},
)
model.eval()
streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=True, # 今回は返信に使うため、入力文は返さない
skip_special_tokens=True, # </s> などの特殊トークンも不要
)
# 履歴とメッセージからプロンプトを作成する
def compose_prompt(message: str, history: list[list[str]]):
prompt = f"""<s>[INST]<<SYS>>
{SYSTEM_PROMPT}
<</SYS>>
"""
if len(history) == 0:
prompt += f"{message} [/INST] "
return prompt
else:
first_pair = history[0]
[user, assistant] = first_pair
prompt += f"{user} [/INST] {assistant} </s>"
for pair in history[0:]:
[user, assistant] = pair
prompt += f"<s>[INST] {user} [/INST] {assistant} </s>"
prompt += f"<s>[INST] {message} [/INST] "
return prompt
# ストリーマーを返してあげる関数
async def gen_stream(
prompt: str,
) -> TextIteratorStreamer:
input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
config = dict(
**input_ids,
max_new_tokens=512,
streamer=streamer,
do_sample=True,
top_k=20,
top_p=0.95,
temperature=1.0,
num_beams=1, # 現時点では2以上を設定するとエラー (ビームサーチは使えない)
)
thread = Thread(target=model.generate, kwargs=config)
thread.start()
return streamer
async def chat(message, history):
prompt = compose_prompt(message, history)
streamer = await gen_stream(prompt)
total_response = ""
print(prompt, end="")
for output in streamer:
if not output:
continue
print(output, end="")
# output は新たにデコードされた文字のみが入っている
total_response += output
total_response = "\n".join(
[line.lstrip() for line in total_response.split("\n")]
) # 左側に謎の空白が発生する場合があるので除去する
# ここでは新規の文字ではなく応答文全体を返す必要がある
# そのため過去に生成した文字はとっておく必要がある
yield total_response
# デモを起動
demo = gr.ChatInterface(chat).queue()
demo.launch(share=True)
実行例:
これにより生成している間の待ち時間がなくなり、快適にチャットが可能になりました。
(動画を撮影するのが面倒でストリーム生成されているところを見せられないので、実際に動かして体験してみるとよいかもです)
Colab ノートブック:
おわり
生成が終わるまで待ってから出力を確認する日々とはもうおさらばです。
モデル名やプロンプトを構成する関数を変更すれば、他のインストラクション系モデルにも使えるため、チャット形式で確認をしたいときに便利です。
また、インストラクション系じゃないモデル(stockmark/stockmark-13b
など)であっても、streamer
を
streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=False, # プロンプトも返すようにする
skip_special_tokens=True,
)
のようにすることで、それっぽい感じに表示されるようになり便利です。
参考
Discussion