💬

TransformersでLLMの出力をストリーム生成してQOL向上

2023/10/31に公開

Transformers の TextStreamer 機能を使ってストリーム生成を行います。また、Gradio の ChatInterface と組み合わせて快適にチャットするサンプルコードも紹介します。

今回の環境

  • Python 3.10.12
  • Transformers 4.34.1
  • Gradio 3.50.2

今回のコードを試せる Colab ノートブック:

https://gist.github.com/p1atdev/4a13942fe5c1de875b8b34a59aa1c858#file-llm-chat-playground-ipynb

TextStreamer でストリーム出力

https://huggingface.co/docs/transformers/internal/generation_utils#transformers.TextStreamer

TextStreamer は Transformers ライブラリの機能で、これを使うとモデルの generate() 関数で生成する際にストリームでテキストをプリントしてくれます。

以下は elyza/ELYZA-japanese-Llama-2-7b-fast-instruct を使ったコード例:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

MODEL_NAME = "elyza/ELYZA-japanese-Llama-2-7b-fast-instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map={"": 0},
)
model.eval()

# ストリーマーを定義
streamer = TextStreamer(
    tokenizer,
    skip_prompt=False, # 入力文(ユーザーのプロンプトなど)を出力するかどうか
    skip_special_tokens=False, # その他のデコード時のオプションもここで渡す
)

# Llama2形式のプロンプト
prompt = """<s>[INST]<<SYS>>
優秀で聡明なアシスタントAIとしてユーザーの要望に応じなさい。
<</SYS>>

猫は液体ですか? [/INST] """

inputs = tokenizer(
    prompt, 
    return_tensors="pt", 
    add_special_tokens=False, # 先頭に<s>が付与されるのを回避
)

_ = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    top_k=20,
    top_p=0.95,
    num_beams=1, # 現時点では2以上を設定するとエラー (ビームサーチは使えない)
    streamer=streamer, # ストリーマーを渡す
)

これを実行すると、以下がストリームで出力されます。

<s> [INST]<<SYS>>
優秀で聡明なアシスタントAIとしてユーザーの要望に応じなさい。
<</SYS>>

猫は液体ですか? [/INST]  「猫が液体である」という説も、確かに有り得ます。
しかし、「液体の状態を維持するために、常に一定の圧力がかかる必要がある」という条件に鑑みて考えると、現実世界では非常に難しいことでしょう…。</s>

生成が終わってから一度に出力されるわけではないため、待ち時間が軽減されてQOL向上間違いなしです。

Gradio の ChatInterface でストリーム返信

Gradio にはチャット専用のインターフェースである ChatInterface が用意されています。

https://www.gradio.app/docs/chatinterface

これを使うことでかなり簡単にストリーム生成に対応したチャットデモを作成することができます。

ここでは TextStreamer のかわりに TextIteratorStreamer を使用して、イテレーターから生成された文字を扱います。

コード例:

from threading import Thread

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import gradio as gr

SYSTEM_PROMPT = "あなたは誠実で優秀な日本人のアシスタントです。"

MODEL_NAME = "elyza/ELYZA-japanese-Llama-2-7b-fast-instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map={"": 0},
)
model.eval()

streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=True, # 今回は返信に使うため、入力文は返さない
    skip_special_tokens=True, # </s> などの特殊トークンも不要
)

# 履歴とメッセージからプロンプトを作成する
def compose_prompt(message: str, history: list[list[str]]):
    prompt = f"""<s>[INST]<<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

"""
    if len(history) == 0:
        prompt += f"{message} [/INST] "
        return prompt
    else:
        first_pair = history[0]
        [user, assistant] = first_pair

        prompt += f"{user} [/INST] {assistant} </s>"

        for pair in history[0:]:
            [user, assistant] = pair
            prompt += f"<s>[INST] {user} [/INST] {assistant} </s>"
        
        prompt += f"<s>[INST] {message} [/INST] "
        return prompt


# ストリーマーを返してあげる関数
async def gen_stream(
    prompt: str,
) -> TextIteratorStreamer:
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    config = dict(
        **input_ids,
        max_new_tokens=512,
        streamer=streamer,
        do_sample=True,
        top_k=20,
        top_p=0.95,
        temperature=1.0,
        num_beams=1, # 現時点では2以上を設定するとエラー (ビームサーチは使えない)
    )

    thread = Thread(target=model.generate, kwargs=config)
    thread.start()

    return streamer

async def chat(message, history):
    prompt = compose_prompt(message, history)

    streamer = await gen_stream(prompt)

    total_response = ""

    print(prompt, end="") 

    for output in streamer:
        if not output:
            continue

        print(output, end="")
        # output は新たにデコードされた文字のみが入っている
        total_response += output
        total_response = "\n".join(
            [line.lstrip() for line in total_response.split("\n")]
        ) # 左側に謎の空白が発生する場合があるので除去する

        # ここでは新規の文字ではなく応答文全体を返す必要がある
        # そのため過去に生成した文字はとっておく必要がある
        yield total_response

# デモを起動
demo = gr.ChatInterface(chat).queue()
demo.launch(share=True)

実行例:

Stream chat demo using Gradio

これにより生成している間の待ち時間がなくなり、快適にチャットが可能になりました。

(動画を撮影するのが面倒でストリーム生成されているところを見せられないので、実際に動かして体験してみるとよいかもです)

Colab ノートブック:

https://gist.github.com/p1atdev/4a13942fe5c1de875b8b34a59aa1c858#file-llm-chat-playground-ipynb

おわり

生成が終わるまで待ってから出力を確認する日々とはもうおさらばです。

モデル名やプロンプトを構成する関数を変更すれば、他のインストラクション系モデルにも使えるため、チャット形式で確認をしたいときに便利です。

また、インストラクション系じゃないモデル(stockmark/stockmark-13b など)であっても、streamer

streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=False, # プロンプトも返すようにする
    skip_special_tokens=True,
)

のようにすることで、それっぽい感じに表示されるようになり便利です。

参考

https://cockscomb.hatenablog.com/entry/streaming-with-huggingface-transformers

GitHubで編集を提案

Discussion