🪔

LitServe で画像生成サーバーを建てる

2024/12/12に公開

この記事は(あんまりLLMに関係ないですが...) LLM・LLM活用 Advent Calendar 2024 シリーズ2 11日目の記事です。

https://qiita.com/advent-calendar/2024/large-language-model

はじめに

近年、テキスト生成や画像生成などの様々な機械学習モデルが登場しています。テキスト生成はその大規模な流行から、vLLM に代表されるような OpenAI API と互換性のある使いやすい推論サーバーや推論のためのライブラリが整っています。

しかし、画像生成では手軽に推論サーバーを建てるためのツールは整備されていないように感じます。AUTOMATIC1111 の WebUIComfyUI などの WebUI のついた推論ツールには、HTTP API を利用できる機能があることがありますが、 非常に多機能 であるため、使いこなすためには 一定の学習コスト がかかります。

そこで今回は、LitServe を用いて シンプルな画像生成サーバー を建てる方法を紹介します。

LitServe とは何ですの?

LitServe は、Lightning AI が開発している機械学習モデルの推論サーバーを簡単に建てるためのツールです。

https://github.com/Lightning-AI/LitServe

単純にサーバーを建てるだけでなく、オートスケーリングや認証の機能もつけられるそうですが、今回は深いところまでは触らないです。

このライブラリ自体は画像生成専用というわけではなく、もちろんテキスト生成モデルでも使えますし、音声認識など他のタスクでも使うことができます。詳細は公式の Examples を参照してください:

https://lightning.ai/docs/litserve/examples

今回の目標

今回は、LitServe と diffusers を使って必要最低限の機能を持った画像生成サーバーを作成します。画像生成モデルとしては、OnomaAIResearch/Illustrious-xl-early-release-v0 を使用します。

完成品

完成品のコードは以下のレポにあります:

https://github.com/p1atdev/sdxl-litserve

使用したモデル:

プロジェクト作成

今回はPython のバージョン管理や仮想環境管理として uv を使いますが各自好きなものを使って良いです。
torchdiffusers 等、画像生成に必要案ものを事前にインストールしておきます。

litserve をインストールするには以下のコマンドを実行します:

uv add litserve

最終的な依存関係は以下の通りです:

pyproject.toml
pyproject.toml
[project]
name = "sdxl-litserve"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11,<3.12"
dependencies = [
    "torch<2.5,>=2.4",
    "torchvision>=0.19.0",
    "safetensors>=0.4.5",
    "hf-transfer>=0.1.8", 
    "tqdm>=4.67.1",
    "transformers>=4.47.0",
    "diffusers>=0.31.0",
    "litserve>=0.2.5",
    "accelerate>=1.2.0",
    "pydantic>=2.10.3",
]

[tool.uv]
# これはフォーマッター
dev-dependencies = ["ruff>=0.8.0"]

[tool.uv.sources]
torch = [{ index = "pytorch-cu124", marker = "platform_system != 'Darwin'" }]
torchvision = [
    { index = "pytorch-cu124", marker = "platform_system != 'Darwin'" },
]

[[tool.uv.index]]
name = "pytorch-cu124"
url = "https://download.pytorch.org/whl/cu124"
explicit = true

今回はリクエストパラメータのバリデーションで pydantic を使います。

実装

公式の一番シンプルな例では次のようなコードが示されています。

import litserve as ls

class SimpleLitAPI(ls.LitAPI):
    def setup(self, device):
        self.model1 = lambda x: x**2
        self.model2 = lambda x: x**3

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        squared = self.model1(x)
        cubed = self.model2(x)
        output = squared + cubed
        return {"output": output}

    def encode_response(self, output):
        return {"output": output}

if __name__ == "__main__":
    api = SimpleLitAPI()
    server = ls.LitServer(api, accelerator="gpu")
    server.run(port=8000)

ls.LitAPI を継承して、setup メソッドでモデルを初期化 し、decode_request メソッドで POST リクエストの body をパース し、predict メソッドで推論を行い、最後に encode_response メソッドで推論結果を整形 するという流れです。LitServe ではモデルの読み込みは setup 内で行うのが慣習のようです。

これを元にして、画像生成モデルを使った API を作成していきます。
AuraFlow での実装例 を参考に次のようになりました:

text2image.py
import argparse
from PIL import Image

from io import BytesIO
from fastapi.responses import Response
from pydantic import BaseModel, field_validator

import torch
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, AutoencoderKL

import litserve as ls


def prepare_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="OnomaAIResearch/Illustrious-xl-early-release-v0",
    )
    parser.add_argument("--vae", type=str, default="madebyollin/sdxl-vae-fp16-fix")
    parser.add_argument("--port", type=int, default=8000)
    return parser.parse_args()


class GenerationParams(BaseModel):
    prompt: str
    negative_prompt: str = "bad quality, worst quality, lowres, bad anatomy, sketch, jpeg artifacts, ugly, poorly drawn, signature, watermark, bad anatomy, bad hands, bad feet, retro, old, 2000s, 2010s, 2011s, 2012s, 2013s, multiple views, screencap"
    inference_steps: int = 25
    cfg_scale: float = 6.5
    width: int = 768
    height: int = 1024

    @field_validator("width", "height")
    def check_divisible_by_64(cls, value):
        if value % 64 != 0:
            raise ValueError(f"{value} is not divisible by 64")
        return value


class T2IModel:
    def __init__(self, model_name: str, vae_name: str) -> None:
        vae = AutoencoderKL.from_pretrained(
            vae_name,
            torch_dtype=torch.float16,
        )
        pipe = DiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path=model_name,
            vae=vae,
            torch_dtype=torch.float16,
            custom_pipeline="lpw_stable_diffusion_xl",
            add_watermarker=False,
        )
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipe.scheduler.config
        )
        self.pipe = pipe

    def generate(
        self,
        params: GenerationParams,
    ):
        image = self.pipe(
            prompt=params.prompt,
            negative_prompt=params.negative_prompt,
            num_inference_steps=params.inference_steps,
            guidance_scale=params.cfg_scale,
            width=params.width,
            height=params.height,
            return_type="pil",
        ).images[0]  # type: ignore

        return image


class SimpleLitAPI(ls.LitAPI):
    def __init__(self, args):
        super().__init__()

        self.model_name = args.model_name
        self.vae_name = args.vae

    def setup(self, device):
        self.model = T2IModel(self.model_name, self.vae_name)
        self.model.pipe.to(device)

    def decode_request(self, request: dict):
        params = GenerationParams(**request)
        return params

    def predict(self, params: GenerationParams):
        image = self.model.generate(params)
        return image

    def encode_response(self, image: Image.Image):
        buffered = BytesIO()
        image.save(buffered, format="WEBP")

        return Response(
            content=buffered.getvalue(), headers={"Content-Type": "image/webp"}
        )


def main():
    args = prepare_args()
    server = ls.LitServer(SimpleLitAPI(args), accelerator="auto", max_batch_size=1)
    server.run(port=args.port)


if __name__ == "__main__":
    main()

diffusers のパイプラインを一度 T2IModel としてラップすることで SimpleLitAPI で呼び出すときに簡潔に書けるようにしています。引数の指定やパラメータのバリデーションがついたくらいで、基本的な流れは公式の例と変わりません。

気をつけるべき点は、デフォルトの encode_response は特に何もしないと JSON にパースしようとするため、画像を返したい場合は fastapiResponse に入れてから返す 必要があります。

サーバーの起動・生成

次のコマンドでサーバーを起動します:

python ./text2image.py

すると localhost:8000 でサーバーが立ち上がるので、https://localhost:8000/predict に POST リクエストを送ると画像が返ってきます。


Postmanから叩いた様子

ちゃんと画像が返ってきました!

おわりに

今回は LitServe を用いて画像生成サーバーを建てる方法を紹介しました。画像生成 WebUI に付属している既存の API サーバーはクライアントにとってあまり使いやすいものではないことが多いですが、LitServe を使ったミニマルな実装では 余計なものがついてこない ため、本質的でない部分に意識を割かれることなくクライアントからもシンプルに呼び出す ことができました。

LitServe は画像生成だけでなく他のタスクでも利用できるため、一般的な推論ツールでは対応していない処理推論サーバーが用意されていないモデルであっても、LitServe を使うことで手軽にサーバーを建てることができるので、機械学習モデルの推論サーバーを建てる際の選択肢の一つ になるのではないでしょうか。

今回のコード

https://github.com/p1atdev/sdxl-litserve

GitHubで編集を提案

Discussion