StreamlitとFastAPIで非同期推論MLアプリを作る
StreamlitはPythonだけでwebアプリを作ることができるツール(ライブラリ)です。フロントに関する知識がほとんど不要なため、簡単なダッシュボードやデモアプリを作るのに適しています。公式のページでは様々なサンプルアプリが公開されています。
ところで機械学習(特に深層学習)モデルでは、例えば画像1枚あたり数秒の推論時間がかかることもあります。Streamlitは機械学習のデモアプリ用途としても適していると思いますが、推論に時間がかかる場合にいちいち推論完了を待つのは退屈かもしれません。ここではPythonのwebフレームワークであるFastAPIを組み合わせることで、推論を非同期で行う画像認識アプリケーションを作ります。
コードはこちらに配置しました。
アプリ内容
StreamlitによるGUIは以下のようになります。画像をアップロードし、「Submit」ボタンを押すことで画像認識を行います。この時、FastAPI側ではCNNによる推論をバックグランドジョブとして登録し、画面には(推論完了を待たず)すぐに「◯ files are submitted」と表示されます。
「Refresh」ボタンを押すとバックエンドへ処理状況の問い合わせが行われ、推論が完了したファイル一覧が表示されます。本当は推論結果がどうなっているかまで画面に表示するべきですが、今回は技術検証ということで割愛してます。
FastAPI側ではCNNによって画像分類を行い、以下のような結果を画像として保存しています。
Streamlit(フロント)
Streamlit部分はシンプルです。st.file_uploader
を通じて画像をアップロードし、POSTメソッドを通じてバックエンドの/predict
APIに渡します。
推論完了リストは/results
APIに問い合わせることで取得します。レスポンスにはファイルパスのリストが格納されており、これをPandasのデータフレームとして整形してから画面に表示します。
st.button('Refresh')
は一見意味がないようですが、Streamlitではボタンやスライダーなどの状態が変化するたびにスクリプト全体を実行しなおします。これによってボタンが押されるたびに/results
APIへの問い合わせが行われ、推論完了リストが更新されます。
import os
import httpx
import streamlit as st
import pandas as pd
from typing import List
def format_results(result_files: List[str]) -> pd.DataFrame:
job_indices, filenames = [], []
for _, job_id, filename in map(lambda s: s.split('/'), result_files):
job_indices.append(job_id)
filenames.append(filename)
df = pd.DataFrame({'job_id': job_indices, 'filename': filenames})
return df
BACKEND_HOST = os.environ.get('BACKEND_HOST', '127.0.0.1:80')
image_files = st.file_uploader('Target image file',
type=['png', 'jpg'],
accept_multiple_files=True)
if len(image_files) > 0 and st.button('Submit'):
files = [('files', file) for file in image_files]
r = httpx.post(f'http://{BACKEND_HOST}/predict', files=files)
st.success(r.json())
if st.button('Refresh'):
st.success('Refreshed')
r = httpx.get(f'http://{BACKEND_HOST}/results')
df_results = format_results(r.json())
st.write(df_results)
Dockerfileは以下のようになっています。依存関係はPoetryで管理しました。
FROM python:3.8-slim
EXPOSE 8080
WORKDIR /app
RUN pip install poetry
COPY poetry.lock pyproject.toml .
RUN poetry config virtualenvs.create false \
&& poetry install --no-interaction --no-ansi
COPY . .
CMD ["streamlit", "run", "app.py", "--server.port", "8080"]
FastAPI(バックエンド)
バックエンドでは推論処理を受け付けるための/predict
APIと、推論完了リストを返すための/results
APIを作成しました。
/predict
APIは渡された画像に対して推論を行います。この時FastAPIのBackgroundTasks
を用いて推論処理を登録することで、推論の完了を待たずにAPIからのレスポンスを返すことができます。今回は簡単のためKerasのEfficientNetB0を用いており、ローカルから重みをロードしています。重みはGitHubリポジトリにはプッシュしていないため、再現実行するときには各自で用意してください。
/results
APIは推論の結果生成した画像ファイルの一覧を返します。
import os
import sys
import time
import asyncio
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List, Dict
from datetime import datetime
from fastapi import FastAPI, BackgroundTasks, File, UploadFile
from tensorflow.keras.applications.efficientnet import EfficientNetB0, decode_predictions
IMAGE_SIZE = (224, 224)
plt.switch_backend('Agg')
app = FastAPI()
model = EfficientNetB0(weights=None)
model.load_weights('weights/effnet-b0.ckpt')
def save_prediction(image: np.ndarray,
classes: List[str],
probs: List[float],
savepath: Path) -> None:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.set_title('Input image')
ax1.imshow(image)
ax2.set_title('Top probabilities')
ax2.barh(classes, probs)
ax2.invert_yaxis()
fig.tight_layout()
plt.savefig(savepath)
def predict_images(files: List[UploadFile], job_id: str) -> None:
savedir = Path(f'./results/{job_id}')
if not savedir.exists():
savedir.mkdir(parents=True)
for file in files:
image = tf.io.decode_image(file.file.read())
image = tf.image.resize_with_pad(image, *IMAGE_SIZE)
pred = model.predict(image[None])
pred = decode_predictions(pred)[0]
image = image.numpy().astype(np.uint8)
_, classes, probs = list(zip(*pred))
savepath = savedir/file.filename
save_prediction(image, classes, probs, savepath=savepath)
@app.post('/predict')
async def predict(files: List[UploadFile] = File(...),
background_tasks: BackgroundTasks = None):
job_id = datetime.now().strftime("%Y%m%d_%H%M%S")
background_tasks.add_task(predict_images, files=files, job_id=job_id)
return f'{len(files)} files are submitted'
@app.get('/results')
async def results():
p = Path('results')
# results/yyyymmdd_hhmmss/(png|jpg)
result_files = [str(pp) for pp in p.glob('*/*')]
return result_files
FastAPIは/docs
にアクセスすることで、各APIの仕様を確認することができます。実際にAPIを叩くこともでき便利です。
Dockerイメージに関しては、FastAPIは開発者の方がGunicornとUvicornを組み込んだベースイメージを公開しています。これを用いることでGunicorn(WSGI)によるマルチプロセスと、Uvicorn(ASGI)による非同期処理の恩恵を受けることができます。せっかくなので使ってみます。
FROM tiangolo/uvicorn-gunicorn-fastapi:python3.8-slim
RUN pip install poetry
COPY ./pyproject.toml ./poetry.lock* /app/
RUN poetry config virtualenvs.create false \
&& poetry install --no-interaction --no-ansi --no-root --no-dev
COPY . /app/
上記のベースイメージにはstart.sh
という起動スクリプトが保持されており、コンテナ起動時に対象のFastAPIモジュールを環境変数APP_MODULE
を通じて指定する必要があります。例えば今回はserver.py
のapp
モジュールが対象となるので、以下のようにコンテナを起動することができます。
# Build
$ docker image build ./fastapi-backend -t st-fastapi-example/backend:latest
# Run
$ docker container run -d --rm -p 80:80 -e APP_MODULE=server:app --name backend st-fastapi-example/backend:latest
Docker Compose
上記のフロント・バックエンドのコンテナをまとめて起動するComposeファイルは以下のようにしました。コンテナ間で通信するためにフロントの環境変数としてBACKEND_HOST
を設定しています。またバックエンドでは上記のAPP_MODULE
を設定しています。
version: '3.8'
services:
app:
build:
context: ./streamlit-front
dockerfile: Dockerfile
image: st-fastapi-example/app:latest
container_name: app
restart: unless-stopped
depends_on:
- backend
networks:
- st_fastapi_net
environment:
BACKEND_HOST: backend:80
ports:
- 8080:8080
backend:
build:
context: ./fastapi-backend
dockerfile: Dockerfile
image: st-fastapi-example/backend:latest
container_name: backend
restart: unless-stopped
networks:
- st_fastapi_net
ports:
- 80:80
environment:
APP_MODULE: server:app
networks:
st_fastapi_net:
driver: bridge
以上がアプリの構成です。$ docker compose up -d
から各コンテナを起動すれば、ローカルホストの8080番ポートからアプリに接続できるはずです。推論結果の画像はバックエンドのコンテナ内に保存されるので、実際に確認したい場合はマウントするなりコピーするなりしてみてください。
まとめ
HTMLやJavaScriptなどを触ることなく、Pythonのみでそこそこのアプリを構築できました。Streamlitはやはりデモアプリのプロトタイピングに最適であり、FastAPIもシンプルな触り心地のため初心者でもとっつきやすいと思います。
一方でStreamlitが提供しているフロント機能には制限もあります。例えば現段階では、セッション管理や多様な種類のファイルのダウンロード機能までは用意されていません。ただし開発者のTwitterによるとこの辺の機能も鋭意開発中とのことです。続報に期待しましょう。
次はGKE Autopilotへのデプロイとか試してみようと思います。その場合は解析結果をバックエンドのローカルディスクではなく、GCSに配置するなどの検討も必要そうです。
Discussion