💫

StreamlitとFastAPIで非同期推論MLアプリを作る

2021/06/20に公開

StreamlitはPythonだけでwebアプリを作ることができるツール(ライブラリ)です。フロントに関する知識がほとんど不要なため、簡単なダッシュボードやデモアプリを作るのに適しています。公式のページでは様々なサンプルアプリが公開されています。

https://streamlit.io/

ところで機械学習(特に深層学習)モデルでは、例えば画像1枚あたり数秒の推論時間がかかることもあります。Streamlitは機械学習のデモアプリ用途としても適していると思いますが、推論に時間がかかる場合にいちいち推論完了を待つのは退屈かもしれません。ここではPythonのwebフレームワークであるFastAPIを組み合わせることで、推論を非同期で行う画像認識アプリケーションを作ります。

https://fastapi.tiangolo.com/

コードはこちらに配置しました。

https://github.com/daigo0927/blog/tree/master/streamlit-fastapi-example

アプリ内容

StreamlitによるGUIは以下のようになります。画像をアップロードし、「Submit」ボタンを押すことで画像認識を行います。この時、FastAPI側ではCNNによる推論をバックグランドジョブとして登録し、画面には(推論完了を待たず)すぐに「◯ files are submitted」と表示されます。

「Refresh」ボタンを押すとバックエンドへ処理状況の問い合わせが行われ、推論が完了したファイル一覧が表示されます。本当は推論結果がどうなっているかまで画面に表示するべきですが、今回は技術検証ということで割愛してます。

FastAPI側ではCNNによって画像分類を行い、以下のような結果を画像として保存しています。

Streamlit(フロント)

Streamlit部分はシンプルです。st.file_uploaderを通じて画像をアップロードし、POSTメソッドを通じてバックエンドの/predictAPIに渡します。

推論完了リストは/resultsAPIに問い合わせることで取得します。レスポンスにはファイルパスのリストが格納されており、これをPandasのデータフレームとして整形してから画面に表示します。

st.button('Refresh')は一見意味がないようですが、Streamlitではボタンやスライダーなどの状態が変化するたびにスクリプト全体を実行しなおします。これによってボタンが押されるたびに/resultsAPIへの問い合わせが行われ、推論完了リストが更新されます。

streamlit-front/app.py
import os
import httpx
import streamlit as st
import pandas as pd
from typing import List


def format_results(result_files: List[str]) -> pd.DataFrame:
    job_indices, filenames = [], []
    for _, job_id, filename in map(lambda s: s.split('/'), result_files):
        job_indices.append(job_id)
        filenames.append(filename)
    df = pd.DataFrame({'job_id': job_indices, 'filename': filenames})
    return df


BACKEND_HOST = os.environ.get('BACKEND_HOST', '127.0.0.1:80')


image_files = st.file_uploader('Target image file',
                               type=['png', 'jpg'],
                               accept_multiple_files=True)

if len(image_files) > 0 and st.button('Submit'):
    files = [('files', file) for file in image_files]

    r = httpx.post(f'http://{BACKEND_HOST}/predict', files=files)
    st.success(r.json())


if st.button('Refresh'):
    st.success('Refreshed')
    
r = httpx.get(f'http://{BACKEND_HOST}/results')
df_results = format_results(r.json())
st.write(df_results)

Dockerfileは以下のようになっています。依存関係はPoetryで管理しました。

streamlit-front/Dockerfile
FROM python:3.8-slim
EXPOSE 8080
WORKDIR /app

RUN pip install poetry
COPY poetry.lock pyproject.toml .
RUN poetry config virtualenvs.create false \
    && poetry install --no-interaction --no-ansi

COPY . .
CMD ["streamlit", "run", "app.py", "--server.port", "8080"]

FastAPI(バックエンド)

バックエンドでは推論処理を受け付けるための/predictAPIと、推論完了リストを返すための/resultsAPIを作成しました。

/predictAPIは渡された画像に対して推論を行います。この時FastAPIのBackgroundTasksを用いて推論処理を登録することで、推論の完了を待たずにAPIからのレスポンスを返すことができます。今回は簡単のためKerasのEfficientNetB0を用いており、ローカルから重みをロードしています。重みはGitHubリポジトリにはプッシュしていないため、再現実行するときには各自で用意してください。

/resultsAPIは推論の結果生成した画像ファイルの一覧を返します。

fastapi-backend/server.py
import os
import sys
import time
import asyncio
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List, Dict
from datetime import datetime
from fastapi import FastAPI, BackgroundTasks, File, UploadFile
from tensorflow.keras.applications.efficientnet import EfficientNetB0, decode_predictions


IMAGE_SIZE = (224, 224)

plt.switch_backend('Agg')

app = FastAPI()

model = EfficientNetB0(weights=None)
model.load_weights('weights/effnet-b0.ckpt')


def save_prediction(image: np.ndarray,
                    classes: List[str],
                    probs: List[float],
                    savepath: Path) -> None:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    ax1.set_title('Input image')
    ax1.imshow(image)
    ax2.set_title('Top probabilities')
    ax2.barh(classes, probs)
    ax2.invert_yaxis()
    fig.tight_layout()
    plt.savefig(savepath)


def predict_images(files: List[UploadFile], job_id: str) -> None:
    savedir = Path(f'./results/{job_id}')
    if not savedir.exists():
        savedir.mkdir(parents=True)
    
    for file in files:
        image = tf.io.decode_image(file.file.read())
        image = tf.image.resize_with_pad(image, *IMAGE_SIZE)
        pred = model.predict(image[None])
        pred = decode_predictions(pred)[0]

        image = image.numpy().astype(np.uint8)
        _, classes, probs = list(zip(*pred))

        savepath = savedir/file.filename
        save_prediction(image, classes, probs, savepath=savepath)


@app.post('/predict')
async def predict(files: List[UploadFile] = File(...),
                  background_tasks: BackgroundTasks = None):
    job_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    background_tasks.add_task(predict_images, files=files, job_id=job_id)
    return f'{len(files)} files are submitted'


@app.get('/results')
async def results():
    p = Path('results')
    # results/yyyymmdd_hhmmss/(png|jpg)
    result_files = [str(pp) for pp in p.glob('*/*')]
    return result_files

FastAPIは/docsにアクセスすることで、各APIの仕様を確認することができます。実際にAPIを叩くこともでき便利です。

Dockerイメージに関しては、FastAPIは開発者の方がGunicornとUvicornを組み込んだベースイメージを公開しています。これを用いることでGunicorn(WSGI)によるマルチプロセスと、Uvicorn(ASGI)による非同期処理の恩恵を受けることができます。せっかくなので使ってみます。

https://github.com/tiangolo/uvicorn-gunicorn-fastapi-docker

fastapi-backend/Dockerfile
FROM tiangolo/uvicorn-gunicorn-fastapi:python3.8-slim

RUN pip install poetry
COPY ./pyproject.toml ./poetry.lock* /app/
RUN poetry config virtualenvs.create false \
    && poetry install --no-interaction --no-ansi --no-root --no-dev

COPY . /app/

上記のベースイメージにはstart.shという起動スクリプトが保持されており、コンテナ起動時に対象のFastAPIモジュールを環境変数APP_MODULEを通じて指定する必要があります。例えば今回はserver.pyappモジュールが対象となるので、以下のようにコンテナを起動することができます。

# Build
$  docker image build ./fastapi-backend -t st-fastapi-example/backend:latest
# Run
$ docker container run -d --rm -p 80:80 -e APP_MODULE=server:app --name backend st-fastapi-example/backend:latest

Docker Compose

上記のフロント・バックエンドのコンテナをまとめて起動するComposeファイルは以下のようにしました。コンテナ間で通信するためにフロントの環境変数としてBACKEND_HOSTを設定しています。またバックエンドでは上記のAPP_MODULEを設定しています。

docker-compose.yaml
version: '3.8'
services:
  app:
    build:
      context: ./streamlit-front
      dockerfile: Dockerfile
    image: st-fastapi-example/app:latest
    container_name: app
    restart: unless-stopped
    depends_on:
      - backend
    networks:
      - st_fastapi_net
    environment:
      BACKEND_HOST: backend:80
    ports:
      - 8080:8080

  backend:
    build:
      context: ./fastapi-backend
      dockerfile: Dockerfile
    image: st-fastapi-example/backend:latest
    container_name: backend
    restart: unless-stopped
    networks:
      - st_fastapi_net
    ports:
      - 80:80
    environment:
      APP_MODULE: server:app

networks:
  st_fastapi_net:
    driver: bridge

以上がアプリの構成です。$ docker compose up -dから各コンテナを起動すれば、ローカルホストの8080番ポートからアプリに接続できるはずです。推論結果の画像はバックエンドのコンテナ内に保存されるので、実際に確認したい場合はマウントするなりコピーするなりしてみてください。

まとめ

HTMLやJavaScriptなどを触ることなく、Pythonのみでそこそこのアプリを構築できました。Streamlitはやはりデモアプリのプロトタイピングに最適であり、FastAPIもシンプルな触り心地のため初心者でもとっつきやすいと思います。

一方でStreamlitが提供しているフロント機能には制限もあります。例えば現段階では、セッション管理や多様な種類のファイルのダウンロード機能までは用意されていません。ただし開発者のTwitterによるとこの辺の機能も鋭意開発中とのことです。続報に期待しましょう。

次はGKE Autopilotへのデプロイとか試してみようと思います。その場合は解析結果をバックエンドのローカルディスクではなく、GCSに配置するなどの検討も必要そうです。

Discussion