👋

ローカルLLMのストリームAPI

に公開

ローカルLLMでAPIを作成するPythonコードをご紹介しようと思います。
今回は通常の出力及びストリーム出力の2つをご紹介します。

最後にコード全文を載せますので、ご参照ください。

通常の出力

fastapiを使用してAPIを作成します。

サーバー

まずはサーバー側のコードから見ていきましょう。

from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel
...
@app.post("/call")
async def call_llm_api(item: Item):
    return PlainTextResponse(content=call_llm(tokenizer, model, item.data))

fastapiを利用して、POSTリクエストを受け取り、ローカルLLMの出力を返すAPIです。

クライアント

次にクライアント側のコードを見ていきましょう。postメソッドを使用して、サーバーにリクエストを送信します。

import httpx
...
with httpx.Client() as client:
    response = client.post(url_call, json=data, timeout=timeout)
    print(response.text)

ストリーム出力

次にストリーム出力の場合を見ていきましょう。
ストリーム出力にはServer-Sent Events(SSE)を使用します。
Server-Sent Events(SSE)は、Webブラウザーとサーバー間の一方向の非同期通信方法です。頻繁なリクエスト送信を抑えることができるため、ストリーミング処理に向いています。

サーバー

先ほどとコードはほとんど変わりません。
StreamingResponseを使用すること、media_typeをtext/event-streamにすることで、ストリーム出力が可能になります。

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
...
async def event_stream(items: Item):
    streamer = stream_response_generator(tokenizer, model, items)
    for new_text in streamer:
        if new_text != "":
            yield f"data: {new_text}\n\n"

@app.post("/stream", response_class=StreamingResponse)
async def stream_items(item: Item):
    return StreamingResponse(event_stream(item.data), media_type="text/event-stream")

クライアント

クライアント側のコードもほとんど代わりません。streamを使用して、サーバーにリクエストを送信します。

import httpx
...
with httpx.Client() as client:
    # POSTリクエストを送信
    with client.stream("POST", url_stream, json=data, timeout=timeout) as response:
        # レスポンスのステータスコードを確認
        if response.status_code == 200:
            # ストリーミングされたレスポンスを逐次処理
            for line in response.iter_lines():
                if line:
                    # SSE形式のデータを解析
                    if line.startswith("data: "):
                        event_data = line[6:]
                        print(f"受信したイベント: {event_data}")
        else:
            print(f"エラー: {response.status_code}")

コード全文

最後にコード全文を載せます。

ローカルLLM

今回はllm-jpのモデルを使用しました。1.8bのため、8GBのGPUでは問題なく動きます。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread


def initialize_model():
    tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-3-1.8b-instruct")
    model = AutoModelForCausalLM.from_pretrained(
        "llm-jp/llm-jp-3-1.8b-instruct",
        device_map="auto",
        torch_dtype=torch.bfloat16
    ).eval()
    return tokenizer, model

def stream_response_generator(tokenizer, model, chat):
    # チャットテンプレートの適用とトークナイズ
    tokenized_input = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt"
    ).to(model.device)

    # TextIteratorStreamerの初期化
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # 別スレッドでの生成処理
    generation_kwargs = {
        "input_ids": tokenized_input,
        "max_new_tokens": 100,
        "do_sample": True,
        "top_p": 0.95,
        "temperature": 0.7,
        "repetition_penalty": 1.05,
        "streamer": streamer,
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # メインスレッドでのストリーム出力処理
    # for new_text in streamer:
    #     print(new_text, end="\n", flush=True)
    return streamer

def call_llm(tokenizer, model, chat):
    # チャットテンプレートの適用とトークナイズ
    tokenized_input = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt"
    ).to(model.device)

    # 別スレッドでの生成処理
    generation_kwargs = {
        "input_ids": tokenized_input,
        "max_new_tokens": 100,
        "do_sample": True,
        "top_p": 0.95,
        "temperature": 0.7,
        "repetition_penalty": 1.05,
    }
    output = model.generate(**generation_kwargs)
    input_text = tokenizer.decode(tokenized_input[0], skip_special_tokens=True)
    output_text =tokenizer.decode(output[0], skip_special_tokens=True)
    return output_text[len(input_text):]

def main():
    tokenizer, model = initialize_model()
    chat = [
        {"role": "system", "content": "以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。"},
        {"role": "user", "content": "自然言語処理とは何か"},
    ]
    streamer = stream_response_generator(tokenizer, model, chat)
    for new_text in streamer:
        if new_text != "":
            print(new_text, end="\n", flush=True)

if __name__ == "__main__":
    main()

サーバー

import asyncio
import json
from typing import List, Dict

from fastapi import FastAPI
from fastapi.responses import PlainTextResponse, StreamingResponse
from pydantic import BaseModel

from local_llm_stream import initialize_model, stream_response_generator, call_llm


tokenizer, model = initialize_model()
app = FastAPI()

class Item(BaseModel):
    data: List[Dict[str, str]]

async def event_stream(items: Item):
    streamer = stream_response_generator(tokenizer, model, items)
    for new_text in streamer:
        if new_text != "":
            yield f"data: {new_text}\n\n"

@app.post("/stream", response_class=StreamingResponse)
async def stream_items(item: Item):
    return StreamingResponse(event_stream(item.data), media_type="text/event-stream")

@app.post("/call")
async def call_llm_api(item: Item):
    return PlainTextResponse(content=call_llm(tokenizer, model, item.data))

クライアント

import httpx
import json


# サーバーのURL
url_call = "http://localhost:8000/call"
url_stream = "http://localhost:8000/stream"
timeout = 10.0

# 送信するデータ
data = {
    "data": [
        {"role": "system", "content": "以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。"},
        {"role": "user", "content": "自然言語処理とは何か"},
    ]
}

print("call")
with httpx.Client() as client:
    response = client.post(url_call, json=data, timeout=timeout)
    print(response.text)

print("stream")
# HTTPクライアントの作成
with httpx.Client() as client:
    # POSTリクエストを送信
    with client.stream("POST", url_stream, json=data, timeout=timeout) as response:
        # レスポンスのステータスコードを確認
        if response.status_code == 200:
            # ストリーミングされたレスポンスを逐次処理
            for line in response.iter_lines():
                if line:
                    # SSE形式のデータを解析
                    if line.startswith("data: "):
                        event_data = line[6:]
                        print(f"受信したイベント: {event_data}")
        else:
            print(f"エラー: {response.status_code}")

実行コマンド

# サーバーの起動
uvicorn server:app --reload
# クライアントの実行
python client.py

実行結果例:

call
自然言語処理(Natural Language Processing, NLP)とは、人間の言語をコンピュータが理解し、生成できるようにする技術や方法論を指します。この分野では、テキスト、音声、画像などの形式で提供される情報を解析し、意味や意図を抽出したり、自動的に翻訳したり、対話システムを構築したりすることが可能になります。

NLPの主な目的は、以下のような要素を含むことが多いです:

1. テキストの理解と生成:
   - 自然な文章を理解して生成
stream
受信したイベント: 自然
受信したイベント: 言語
受信したイベント: 処理
受信したイベント: (NLP)とは、人間
受信したイベント: の言語
受信したイベント: を理解
...

参照

Discussion