🙆‍♀️

FastAPIでStreamingResponseに対応したrinna APIを作ってみる

2023/07/03に公開

はじめに

rinnaをサービスに組み込みやすくするためにapi化したけどやっぱりchatGPTみたいにストリーミングしたほうがユーザー体験が良いってことでStreamingResponseに対応したrinna APIを作ってみた。

生成された出力をStreamingする方法

transformersのTextIteratorStreamerクラスを使えば良いらしい。
使い方は簡単で
TextIteratorStreamerにtokenzierを渡して

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
streamer = TextIteratorStreamer(tok)

スレッド立ててmodel.generateに渡すだけ

generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
generated_text = ""
for new_text in streamer:
    generated_text += new_text
generated_text

コード

必要ライブラリ

generate.py
import asyncio
from threading import Thread
from typing import AsyncIterator
from pydantic import BaseModel

import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    TextIteratorStreamer
)
import os
import json

requestクラス作成

generate.py
class request(BaseModel):
    messages: list
    role: bool = True
    max_new_tokens: int = 512
    temperature: float = 0.8

メインのクラス
request2promptでリクエストを入力用文字列に変換する
rinna-baseを独自でFTした場合とかは対話形式に対応してないので
リクエストでプロンプトを対話形式にするか選択できるようにした

role = True

[{
  "speaker": "ユーザー",
  "text": "東京のおすすめ観光スポットは?"
}]
↓
ユーザー:東京のおすすめ観光スポットは?<NL>
システム:

role = false

[{
  "speaker": "ユーザー",
  "text": "東京のおすすめ観光スポットは?"
}]
↓
東京のおすすめ観光スポットは?<NL>
generate.py
class rinna:

    def __init__(self,model_dir):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
        self.model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", torch_dtype=torch.float16)


    def request2prompt(self,request):
        if request.role:
            prompt = [
                f"{uttr['speaker']}: {uttr['text']}"
                for uttr in request.messages
            ]
            # print(prompt)
            prompt = "<NL>".join(prompt)
            prompt = (
                prompt
                + "<NL>"
                + "システム: "
            )

        else:
            prompt = [
                f"{uttr['text']}"
                for uttr in request.messages
            ]
            prompt = "<NL>".join(prompt) + "<NL>"
            
        return prompt

メインの生成部分

generate.py
    async def generate_stream(self,request) -> AsyncIterator[str]:

        prompt = self.request2prompt(request)
        inputs = self.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
        streamer = TextIteratorStreamer(self.tokenizer)

        generation_kwargs = dict(
            inputs.to(self.model.device),
            streamer=streamer,
            max_new_tokens=request.max_new_tokens,
            do_sample=True,
            temperature=request.temperature,
            pad_token_id=self.tokenizer.pad_token_id,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            bad_words_ids=[[self.tokenizer.bos_token_id]],
            num_return_sequences=1,
        )

        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        for output in streamer: 
            if not output:
                continue
            print(output)
            if "</s>" not in output:
                yield json.dumps({
                    "speaker": "システム",
                    "text":output.replace("<NL>", "\n"),
                    "continue":True}, ensure_ascii=False)
            else:
                yield json.dumps({
                    "speaker": "システム",
                    "text":output.replace("<NL>", "\n").replace("</s>", ""),
                    "continue":False}, 
                    ensure_ascii=False)
            await asyncio.sleep(0)
            

    async def generate(self,request):

        prompt = self.request2prompt(request)
        token_ids = self.tokenizer.encode(
        prompt, add_special_tokens=False, return_tensors="pt")

        with torch.no_grad():
            output_ids = self.model.generate(
                token_ids.to(self.model.device),
                do_sample=True,
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                pad_token_id=self.tokenizer.pad_token_id,
                bos_token_id=self.tokenizer.bos_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        output = self.tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
        output = output.replace("<NL>", "\n").replace("</s>", "")
        res_message = {
            "speaker": "システム",
            "text": output
        }

        return res_message 

あとはこれをFastAPIで呼ぶだけ

rinna_api.py
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

import asyncio
from threading import Thread
from typing import AsyncIterator

from generate import rinna, request

rinna_ppo = rinna("/mnt/efs/llm/model/rinna/japanese-gpt-neox-3.6b-instruction-ppo")

app = FastAPI()

@app.get("/")
def read_root():
    return {"Hello": "World"}


@app.get("/chat/")
def read_root():
    return {"status": "ready to chat"}


@app.post("/chat/")
async def response(request:request):
    return rinna_ppo.generate(request)

@app.post("/chat-stream/")
async def response_stream(request:request):
    return StreamingResponse(rinna_ppo.generate_stream(request))

こんな感じのリクエストを投げると

 {
  "messages": 
      [{
  	"speaker": "ユーザー",
  	"text": "東京のおすすめ観光スポットは?"
      }],
  "max_new_tokens":256 ,
  "temperature":0.8 ,
  "role":True ,
}

こんな感じで細切れで送られてきてるのが確認出来た
出力に入力プロンプトが入ってしまうのはtransformer側の仕様っぽい


StreamingResponse

Discussion