🙆♀️
FastAPIでStreamingResponseに対応したrinna APIを作ってみる
はじめに
rinnaをサービスに組み込みやすくするためにapi化したけどやっぱりchatGPTみたいにストリーミングしたほうがユーザー体験が良いってことでStreamingResponseに対応したrinna APIを作ってみた。
生成された出力をStreamingする方法
transformersのTextIteratorStreamerクラスを使えば良いらしい。
使い方は簡単で
TextIteratorStreamerにtokenzierを渡して
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
streamer = TextIteratorStreamer(tok)
スレッド立ててmodel.generateに渡すだけ
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
generated_text = ""
for new_text in streamer:
generated_text += new_text
generated_text
コード
必要ライブラリ
generate.py
import asyncio
from threading import Thread
from typing import AsyncIterator
from pydantic import BaseModel
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer
)
import os
import json
requestクラス作成
generate.py
class request(BaseModel):
messages: list
role: bool = True
max_new_tokens: int = 512
temperature: float = 0.8
メインのクラス
request2promptでリクエストを入力用文字列に変換する
rinna-baseを独自でFTした場合とかは対話形式に対応してないので
リクエストでプロンプトを対話形式にするか選択できるようにした
role = True
[{
"speaker": "ユーザー",
"text": "東京のおすすめ観光スポットは?"
}]
↓
ユーザー:東京のおすすめ観光スポットは?<NL>
システム:
role = false
[{
"speaker": "ユーザー",
"text": "東京のおすすめ観光スポットは?"
}]
↓
東京のおすすめ観光スポットは?<NL>
generate.py
class rinna:
def __init__(self,model_dir):
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
self.model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", torch_dtype=torch.float16)
def request2prompt(self,request):
if request.role:
prompt = [
f"{uttr['speaker']}: {uttr['text']}"
for uttr in request.messages
]
# print(prompt)
prompt = "<NL>".join(prompt)
prompt = (
prompt
+ "<NL>"
+ "システム: "
)
else:
prompt = [
f"{uttr['text']}"
for uttr in request.messages
]
prompt = "<NL>".join(prompt) + "<NL>"
return prompt
メインの生成部分
generate.py
async def generate_stream(self,request) -> AsyncIterator[str]:
prompt = self.request2prompt(request)
inputs = self.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
streamer = TextIteratorStreamer(self.tokenizer)
generation_kwargs = dict(
inputs.to(self.model.device),
streamer=streamer,
max_new_tokens=request.max_new_tokens,
do_sample=True,
temperature=request.temperature,
pad_token_id=self.tokenizer.pad_token_id,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
bad_words_ids=[[self.tokenizer.bos_token_id]],
num_return_sequences=1,
)
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
for output in streamer:
if not output:
continue
print(output)
if "</s>" not in output:
yield json.dumps({
"speaker": "システム",
"text":output.replace("<NL>", "\n"),
"continue":True}, ensure_ascii=False)
else:
yield json.dumps({
"speaker": "システム",
"text":output.replace("<NL>", "\n").replace("</s>", ""),
"continue":False},
ensure_ascii=False)
await asyncio.sleep(0)
async def generate(self,request):
prompt = self.request2prompt(request)
token_ids = self.tokenizer.encode(
prompt, add_special_tokens=False, return_tensors="pt")
with torch.no_grad():
output_ids = self.model.generate(
token_ids.to(self.model.device),
do_sample=True,
max_new_tokens=request.max_new_tokens,
temperature=request.temperature,
pad_token_id=self.tokenizer.pad_token_id,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
output = self.tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
output = output.replace("<NL>", "\n").replace("</s>", "")
res_message = {
"speaker": "システム",
"text": output
}
return res_message
あとはこれをFastAPIで呼ぶだけ
rinna_api.py
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import asyncio
from threading import Thread
from typing import AsyncIterator
from generate import rinna, request
rinna_ppo = rinna("/mnt/efs/llm/model/rinna/japanese-gpt-neox-3.6b-instruction-ppo")
app = FastAPI()
@app.get("/")
def read_root():
return {"Hello": "World"}
@app.get("/chat/")
def read_root():
return {"status": "ready to chat"}
@app.post("/chat/")
async def response(request:request):
return rinna_ppo.generate(request)
@app.post("/chat-stream/")
async def response_stream(request:request):
return StreamingResponse(rinna_ppo.generate_stream(request))
こんな感じのリクエストを投げると
{
"messages":
[{
"speaker": "ユーザー",
"text": "東京のおすすめ観光スポットは?"
}],
"max_new_tokens":256 ,
"temperature":0.8 ,
"role":True ,
}
こんな感じで細切れで送られてきてるのが確認出来た
出力に入力プロンプトが入ってしまうのはtransformer側の仕様っぽい
StreamingResponse
Discussion