LitServe で画像生成サーバーを建てる
この記事は(あんまりLLMに関係ないですが...) LLM・LLM活用 Advent Calendar 2024 シリーズ2 11日目の記事です。
はじめに
近年、テキスト生成や画像生成などの様々な機械学習モデルが登場しています。テキスト生成はその大規模な流行から、vLLM に代表されるような OpenAI API と互換性のある使いやすい推論サーバーや推論のためのライブラリが整っています。
しかし、画像生成では手軽に推論サーバーを建てるためのツールは整備されていないように感じます。AUTOMATIC1111 の WebUI や ComfyUI などの WebUI のついた推論ツールには、HTTP API を利用できる機能があることがありますが、 非常に多機能 であるため、使いこなすためには 一定の学習コスト がかかります。
そこで今回は、LitServe を用いて シンプルな画像生成サーバー を建てる方法を紹介します。
LitServe とは何ですの?
LitServe は、Lightning AI が開発している機械学習モデルの推論サーバーを簡単に建てるためのツールです。
単純にサーバーを建てるだけでなく、オートスケーリングや認証の機能もつけられるそうですが、今回は深いところまでは触らないです。
このライブラリ自体は画像生成専用というわけではなく、もちろんテキスト生成モデルでも使えますし、音声認識など他のタスクでも使うことができます。詳細は公式の Examples を参照してください:
今回の目標
今回は、LitServe と diffusers を使って必要最低限の機能を持った画像生成サーバーを作成します。画像生成モデルとしては、OnomaAIResearch/Illustrious-xl-early-release-v0 を使用します。
完成品
完成品のコードは以下のレポにあります:
使用したモデル:
プロジェクト作成
今回はPython のバージョン管理や仮想環境管理として uv を使いますが各自好きなものを使って良いです。
torch
や diffusers
等、画像生成に必要案ものを事前にインストールしておきます。
litserve
をインストールするには以下のコマンドを実行します:
uv add litserve
最終的な依存関係は以下の通りです:
pyproject.toml
[project]
name = "sdxl-litserve"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11,<3.12"
dependencies = [
"torch<2.5,>=2.4",
"torchvision>=0.19.0",
"safetensors>=0.4.5",
"hf-transfer>=0.1.8",
"tqdm>=4.67.1",
"transformers>=4.47.0",
"diffusers>=0.31.0",
"litserve>=0.2.5",
"accelerate>=1.2.0",
"pydantic>=2.10.3",
]
[tool.uv]
# これはフォーマッター
dev-dependencies = ["ruff>=0.8.0"]
[tool.uv.sources]
torch = [{ index = "pytorch-cu124", marker = "platform_system != 'Darwin'" }]
torchvision = [
{ index = "pytorch-cu124", marker = "platform_system != 'Darwin'" },
]
[[tool.uv.index]]
name = "pytorch-cu124"
url = "https://download.pytorch.org/whl/cu124"
explicit = true
今回はリクエストパラメータのバリデーションで pydantic
を使います。
実装
公式の一番シンプルな例では次のようなコードが示されています。
import litserve as ls
class SimpleLitAPI(ls.LitAPI):
def setup(self, device):
self.model1 = lambda x: x**2
self.model2 = lambda x: x**3
def decode_request(self, request):
return request["input"]
def predict(self, x):
squared = self.model1(x)
cubed = self.model2(x)
output = squared + cubed
return {"output": output}
def encode_response(self, output):
return {"output": output}
if __name__ == "__main__":
api = SimpleLitAPI()
server = ls.LitServer(api, accelerator="gpu")
server.run(port=8000)
ls.LitAPI
を継承して、setup
メソッドでモデルを初期化 し、decode_request
メソッドで POST リクエストの body をパース し、predict
メソッドで推論を行い、最後に encode_response
メソッドで推論結果を整形 するという流れです。LitServe ではモデルの読み込みは setup
内で行うのが慣習のようです。
これを元にして、画像生成モデルを使った API を作成していきます。
AuraFlow での実装例 を参考に次のようになりました:
import argparse
from PIL import Image
from io import BytesIO
from fastapi.responses import Response
from pydantic import BaseModel, field_validator
import torch
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, AutoencoderKL
import litserve as ls
def prepare_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="OnomaAIResearch/Illustrious-xl-early-release-v0",
)
parser.add_argument("--vae", type=str, default="madebyollin/sdxl-vae-fp16-fix")
parser.add_argument("--port", type=int, default=8000)
return parser.parse_args()
class GenerationParams(BaseModel):
prompt: str
negative_prompt: str = "bad quality, worst quality, lowres, bad anatomy, sketch, jpeg artifacts, ugly, poorly drawn, signature, watermark, bad anatomy, bad hands, bad feet, retro, old, 2000s, 2010s, 2011s, 2012s, 2013s, multiple views, screencap"
inference_steps: int = 25
cfg_scale: float = 6.5
width: int = 768
height: int = 1024
@field_validator("width", "height")
def check_divisible_by_64(cls, value):
if value % 64 != 0:
raise ValueError(f"{value} is not divisible by 64")
return value
class T2IModel:
def __init__(self, model_name: str, vae_name: str) -> None:
vae = AutoencoderKL.from_pretrained(
vae_name,
torch_dtype=torch.float16,
)
pipe = DiffusionPipeline.from_pretrained(
pretrained_model_name_or_path=model_name,
vae=vae,
torch_dtype=torch.float16,
custom_pipeline="lpw_stable_diffusion_xl",
add_watermarker=False,
)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config
)
self.pipe = pipe
def generate(
self,
params: GenerationParams,
):
image = self.pipe(
prompt=params.prompt,
negative_prompt=params.negative_prompt,
num_inference_steps=params.inference_steps,
guidance_scale=params.cfg_scale,
width=params.width,
height=params.height,
return_type="pil",
).images[0] # type: ignore
return image
class SimpleLitAPI(ls.LitAPI):
def __init__(self, args):
super().__init__()
self.model_name = args.model_name
self.vae_name = args.vae
def setup(self, device):
self.model = T2IModel(self.model_name, self.vae_name)
self.model.pipe.to(device)
def decode_request(self, request: dict):
params = GenerationParams(**request)
return params
def predict(self, params: GenerationParams):
image = self.model.generate(params)
return image
def encode_response(self, image: Image.Image):
buffered = BytesIO()
image.save(buffered, format="WEBP")
return Response(
content=buffered.getvalue(), headers={"Content-Type": "image/webp"}
)
def main():
args = prepare_args()
server = ls.LitServer(SimpleLitAPI(args), accelerator="auto", max_batch_size=1)
server.run(port=args.port)
if __name__ == "__main__":
main()
diffusers
のパイプラインを一度 T2IModel
としてラップすることで SimpleLitAPI
で呼び出すときに簡潔に書けるようにしています。引数の指定やパラメータのバリデーションがついたくらいで、基本的な流れは公式の例と変わりません。
気をつけるべき点は、デフォルトの encode_response
は特に何もしないと JSON にパースしようとするため、画像を返したい場合は fastapi
の Response
に入れてから返す 必要があります。
サーバーの起動・生成
次のコマンドでサーバーを起動します:
python ./text2image.py
すると localhost:8000
でサーバーが立ち上がるので、https://localhost:8000/predict
に POST リクエストを送ると画像が返ってきます。
Postmanから叩いた様子
ちゃんと画像が返ってきました!
おわりに
今回は LitServe を用いて画像生成サーバーを建てる方法を紹介しました。画像生成 WebUI に付属している既存の API サーバーはクライアントにとってあまり使いやすいものではないことが多いですが、LitServe を使ったミニマルな実装では 余計なものがついてこない ため、本質的でない部分に意識を割かれることなくクライアントからもシンプルに呼び出す ことができました。
LitServe は画像生成だけでなく他のタスクでも利用できるため、一般的な推論ツールでは対応していない処理 や 推論サーバーが用意されていないモデルであっても、LitServe を使うことで手軽にサーバーを建てることができるので、機械学習モデルの推論サーバーを建てる際の選択肢の一つ になるのではないでしょうか。
今回のコード
Discussion