🎯

AIモデルの予測結果の可視化を楽にしたい? Gradioを使ってみよう!

2023/07/31に公開

AIモデルによる推論結果を可視化して、それを社内に共有できるWebアプリケーションを作りたいということが、AIに関わるお仕事だとよくあるかと思います。
こういった機械学習向けのデモライブラリ (フレームワーク) として最近ちょこちょこ話題に上るGradioを触って、使い勝手などを見ていきます。

https://www.gradio.app/

Gradioの公式ガイドを主に参考にしています。

https://www.gradio.app/guides

Gradioとは

Gradioは、機械学習モデルのデモをWebアプリケーションとして簡単に実装・公開できるPythonライブラリです。

機械学習モデルを使ったプロダクトを開発・運用していると、自分が作った入力 (レコードデータや動画像など) を使って、それらのモデルを動かして、結果を可視化する、ということをサクッと行いたいことがしばしばあります。
アドホックな場合はJupyter Notebookで賄うこともありますが、モデルを動かす環境のセットアップだったり、自分以外のメンバーに試してもらったりなどはJupyter Notebookでは困難です。
これに対して、モデルをWebアプリケーションとして公開するというのがよくある解決策です。しかし、フロントエンドとバックエンドのそれぞれで実装が必要というWebアプリケーション開発一般のタスクに加えて、入力データの形式がテキストだけではなく動画像を含んでいたり、GPUなどのリソース制約で同時処理数が制限されるなど、機械学習モデル特有の面倒もあります。
類似ライブラリにあたるStreamlitやDashなどの可視化ツールと比較して、Gradioは上記のような課題を抱える機械学習モデルのデモを主たるターゲットにしています。入出力データの変換をいい感じに行なってくれたり (アップロードされた画像ファイルをNumPy配列に変換して引数に渡されるなど) 、Webカメラを使って画像データをストリーミングで渡したり、Hugging Faceなどを使ったサンプルアプリケーションが充実していたりと、手厚さが伺えます。

Streamlitとの比較では、こちらの投稿が参考になります。
https://aotamasaki.hatenablog.com/entry/gradio-explanation

前提環境とインストール

Gradio 3.37.0を使います。

$ python --version               
Python 3.11.3

$ cat requirements.txt                             
gradio==3.37.0

$ pip install -r requirements.txt

Interfaceを使ったシンプルUI実装

最低限のHelloアプリケーションを実装して動かしてみます。
デフォルトでは127.0.0.1に7860ポートで立ち上がります。

import gradio as gr

def hello(name: str) -> str:
    return f"Hello, {name}!"

hello_interface = gr.Interface(fn=hello, inputs="text", outputs="text")
hello_interface.launch()
$ python app.py
Running on local URL:  http://127.0.0.1:7860
To create a public link, set `share=True` in `launch()`.

ブラウザでアクセスすると、以下のようなシンプルなWebページが表示されます。

Alt text

nameにテキストを入力して送信をクリックすると、outputにテキストが表示されます。

Alt text

Interface

コードで肝になるのがInterfaceクラスで、これを使うことでWebアプリケーションのGUIを高い抽象度で実装できます。

hello_interface = gr.Interface(fn=hello, inputs="text", outputs="text")

Interfaceクラスのコンストラクタで、実行する関数 (fn) 、入力 (inputs) 、出力 (outputs) を指定します。
レイアウトはinputs、outputsのそれぞれで記述された順番でUIコンポネントが配置されます。

関数fnに制約はなく、デモの場合、機械学習モデルによる推論などの処理のエントリポイントとなります。
入力がinputsでの指定にあわせて適切に変換された後に引数として渡され、さらに返り値もoutputsでの指定にあわせて変換されてブラウザ上で表示されます。
inputs/outputsにて文字列で指定できるUIは"text"以外にも、"button"、"radio"、"iamge"、"video"、"file"などがあります。

Interfaceのlaunch()メソッドを呼び出すと、Webアプリケーションサーバが立ち上がります。

hello_interface.launch()

ただし、Interfaceでは細かいレイアウトやボタンのリスナなどの挙動をカスタマイズできません。
Gradioでは、こうしたカスタマイズを行うために、Blocksという仕組みもあります。(後述)

IOComponent

inputs (とoutputs) はリストで指定できます。

def calculate(operator: str, x: int, y: int) -> float:
    match operator:
        case "+":
            return x + y
        case "/":
            return x / y

calculate_interface = gr.Interface(
    fn=calculate, inputs=["text", "text", "text"], outputs=["text"]
)
calculate_interface.launch()

ところがこれではinputsがtextなので、xとyはstrで受けてしまって "10"+"2"="102" で出力されてしまいます。また、operatorに指定可能な値や、xとyの意味について、UI上での説明が不十分です。

Alt text

GradioではTextbox、Number、Imageなど、IOComponentを継承したUIコンポネントクラスがgradioパッケージ直下に多数用意されています。Interfaceのinputsやoutputsでこれらの継承クラスを使うことで、細かいUIの調整ができます。
以下では、Dropdownによる演算子の選択、Numberによる入力を行なっています。

calculate_interface = gr.Interface(
    fn=calculate,
    inputs=[
        gr.Dropdown(["+", "/"]),
        gr.Number(0, info="Left operand"),
        gr.Number(0, info="Right operand"),
    ],
    outputs=gr.Number(0.0),
)
calculate_interface.launch()

Alt text

examples

examples引数にリストで入力例を渡すこともできます。

calculate_interface = gr.Interface(
    fn=calculate_with_error,
    inputs=[
        gr.Dropdown(["+", "/"]),
        gr.Number(0, info="Left operand"),
        gr.Number(0, info="Right operand"),
    ],
    outputs=gr.Number(0.0),
    examples=[["+", 10, 2], ["/", 100, 3]],
)

画面上に表形式で表示され、選択すると入力コンポネントに反映されます。

Alt text

Image

画像のコンポネントは "image" で指定できます。inputsに "image" を指定すると画像のアップロードフォームが作られ、outputsに "image" を指定すると画像が表示されます。

デフォルトではnumpy.arrayに変換されますが、type="pil"とするとPIL.Image、type="filepath"とするとファイルパス (strやpathlib.Path) で受け渡すこともできます。

def rotate(image: np.array) -> np.array:
    return np.rot90(image)

rotate_interface = gr.Interface(
    fn=rotate, inputs=gr.Image(shape=(200, 200)), outputs="image"
)
rotate_interface.launch()

アップロードされた画像は縦横200pxでクロップとリサイズがされ、アップロード後に回転された結果が表示されます。

Alt text

Error

関数内でのエラー処理は gr.Error("error message") とすると、メッセージをポップアップで表示してくれます。

def calculate_with_error(operator: str, x: int, y: int) -> float:
    match operator:
        case "+":
            return x + y
        case "/":
            if y == 0:
                raise gr.Error("Denominator is zero.")
            return x / y

Alt text

Flag

送信とクリア以外に、"フラグする"("Flag")というボタンが表示されています。
これはモデルの誤った振る舞いや気になった振る舞いを簡易的にサーバサイドで記録してくれる仕組みで、このボタンがクリックされると入出力情報をCSVファイルとして出力してくれます。
出力先のディレクトリは、Interfaceのflagging_dir引数 (デフォルトでは、カレントディレクトリにflaggedで作られます) で設定可能です。

最初のhelloの例の場合、デフォルトでは以下のようなCSVが出力されます。

name,output,flag,username,timestamp
ohke,"Hello, ohke!!",,,2023-07-31 07:47:22.416613
tanaka,"Hello, tanaka!!",,,2023-07-31 07:47:35.370061

gradioコマンド

pip install gradioによって、gradioでコマンドも同時にインストールされます。
python app.pyの代わりにgradio app.pyで実行すると、スクリプトの保存と同時にリロードされるようになります。

$ gradio app.py
Launching in *reload mode* on: http://127.0.0.1:7860 (Press CTRL+C to quit)
Watching: '/Users/kenta.onishi/private/python-gradio/.venv/lib/python3.11/site-packages/gradio', '/Users/kenta.onishi/private/python-gradio'
Running on local URL:  http://127.0.0.1:7861
To create a public link, set `share=True` in `launch()`.

# ここでapp.pyを保存...
WARNING:  StatReload detected changes in 'app.py'. Reloading...
Running on local URL:  http://127.0.0.1:7861
To create a public link, set `share=True` in `launch()`.

BlocksでUIのカスタマイズ実装

1アクションで済むような簡単なUIであればInterfaceで十分ですが、画面構成をカスタマイズしたり、関数の出力を別の関数の入力にしたりといったことはできません。
Blocksは、UIコンポネントを組み合わせて、より複雑な画面構成やデータフローを実現するための仕組みです。

Blocksはwith節内 with gr.Blocks() as blocks: で作成します。
with節内で作成されたUIコンポネントが自動的に追加・表示されます。デフォルトでは縦に並べられます。
Interfaceでは隠蔽されていたButtonを作成し、clickメソッドで実行する関数と、先に作成した入力コンポネントと出力コンポネントを紐づけます。

with gr.Blocks() as hello_blocks_interface:
    input = gr.Textbox(label="Name")
    output = gr.Textbox(label="Output Box")
    button = gr.Button("Hello")
    button.click(fn=hello, inputs=input, outputs=output, api_name="hello")

hello_blocks_interface.launch()

レイアウト (Tab, Row, Col)

Tab()でタブ分割、Row()で横並び、Col()で縦並びにできます。
さらにこれらをwith節で入れ子にできるので、複雑なレイアウトも実装できます。

with gr.Blocks() as calculator_blocks:
    with gr.Tab("Sqrt"):
        x = gr.Number(label="x")
        sqrt_button = gr.Button("Calculate")
    with gr.Tab("Add"):
        with gr.Row():
            a = gr.Number(label="a")
            b = gr.Number(label="b")
        add_button = gr.Button("Calculate")

    output = gr.Number(label="Result")

    sqrt_button.click(lambda x: x**0.5, inputs=x, outputs=output)
    add_button.click(lambda a, b: a + b, inputs=[a, b], outputs=output)

calculator_blocks.launch()

Alt text

出力を別のUIコンポネントの入力にする

UIコンポネントの出力を、別のUIコンポネントの入力に渡すこともできます。
以下の例では、xを入力として受け取ってyを出力し、yを入力として受け取ってzを出力しています。

with gr.Blocks() as multi_step_blocks:
    x = gr.Number(label="x")
    y = gr.Number(label="y")
    z = gr.Number(label="z")

    x_plus_1_button = gr.Button("x + 1")
    x_plus_1_button.click(fn=lambda x: x + 1, inputs=x, outputs=y)

    y_plus_1_button = gr.Button("y + 1")
    y_plus_1_button.click(fn=lambda y: y + 1, inputs=y, outputs=z)

multi_step_blocks.launch()

Alt text

click以外のイベントリスナ

Button以外のUIコンポネントもイベントトリガとして利用できます。
UIコンポネントによってイベントリスナの設定メソッドは異なります。TextboxやNumberの場合、値の変更時にトリガされるchange()でセットします。

先の例をchange()を使って書き換えると以下のようになります。xを変更すると、yもzも同時に変更されます。

with gr.Blocks() as multi_step_blocks:
    x = gr.Number(label="x")
    y = gr.Number(label="y")
    z = gr.Number(label="z")

    x.change(fn=lambda x: x + 1, inputs=x, outputs=y)
    y.change(fn=lambda y: y + 1, inputs=y, outputs=z)

multi_step_blocks.launch()

各UIコンポネントが提供するイベントの一覧は Gradio Component Docs を参照ください。

複数の入力

2つ以上の入力引数の場合、リストで渡すことができます。

with gr.Blocks() as multi_inputs_blocks:
    x = gr.Number(label="x")
    y = gr.Number(label="y")
    z = gr.Number(label="z")

    x_plus_y_button = gr.Button("x + y")
    x_plus_y_button.click(lambda x, y: x + y, inputs=[x, y], outputs=z)

リスト以外のオプションとして、{x, y}とするとdictで渡すこともできます。

x_plus_y_button.click(lambda data: data[x] + data[y], inputs={x, y}, outputs=z)

update()によるUIのインタラクティブな更新

入力値に応じてUIコンポネントのプロパティを更新したい場合、update()関数を使います。
以下はvalueとinfoのプロパティを更新する例です。

def random_number(seed: float):
    print(seed)
    np.random.seed(int(seed))
    n = np.random.random()
    return gr.update(value=str(n), info="<0.5" if n < 0.5 else ">=0.5")

with gr.Blocks() as random_number_blocks:
    seed = gr.Number(label="seed")
    output = gr.Textbox(label="random number")
    seed.change(fn=random_number, inputs=seed, outputs=output)

random_number_blocks.launch()

Alt text

Session state

変数をfnのパラメータで渡す関数の外で定義すると、関数内からもその変数を参照できます。ロードに時間のかかるモデルなどは、グローバル変数で参照できるようにしておくことで、毎回のロードを防げます。

以下ではgreeting変数をグローバル変数として定義しています。このgreetingは全ユーザで共有されます。

greeting = "Hello"

def hello(name: str) -> str:
    return f"{greeting} {name}!"

hello_interface = gr.Interface(fn=hello, inputs="text", outputs="text")
hello_interface.launch()

Gradioではページセッション内で保持される状態 (Session state) もサポートしています。
gr.State()でセッション内で共有する状態変数 (以下ではhistory) を作成・初期化して、クリックイベントリスナーに渡し、さらに更新した状態変数を返り値で戻しています。
historyに入力されたmessageがリストで残り、message_countではその件数を表示させています。

def echo(message: str, history: list[str]):
    history.append(message)
    return message, len(history), history

with gr.Blocks() as echo_blocks:
    history = gr.State([])

    message = gr.Textbox(label="Message")
    output = gr.Textbox(label="Output")
    message_count = gr.Number(label="Message count")
    button = gr.Button("Echo")

    button.click(
        fn=echo, inputs=[message, history], outputs=[output, message_count, history]
    )

echo_blocks.launch()

Webアプリケーションとしての機能

認証

launch()のauth引数にユーザ名とパスワードを受け取ってboolを返す関数を渡すことで、認証画面を表示できます。
以下の例では、パスワードがユーザ名の逆順の場合は認証成功とする関数を渡しています。

def authenticate(username: str, password: str) -> bool:
    return password == username[::-1]

demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch(auth=auth)

アクセスすると、以下のような認証画面が表示されます。認証成功すると画面遷移できると同時に、クッキーにアクセストークンが保持されます。

Alt text

Queueing

デフォルトでは同時リクエスト数などの制限は行われません。
InterfaceやBlocksのqueue()を呼ぶことで、同時リクエスト数を制限できます。引数concurrency_countが制限数で、デフォルトは1です。

import time

def long_running_hello(x):
    time.sleep(30)
    return f"Hello, {x}!"

long_running_hello_interface = gr.Interface(
    fn=long_running_hello, inputs="text", outputs="text"
)
long_running_hello_interface.queue()
long_running_hello_interface.launch()

処理中 (上) はprocessingですが、待っている間 (下) はqueue と何番目なのかが表示されます。

Alt text
Alt text

イテレーティブな結果の表示

関数fnの結果をyieldで返すようにした場合、イテレーティブに結果を表示してくれます。
以下のコードでstepsに3を入力した場合、Starting -> 0 -> 1 -> 2 -> Done!が1秒おきで表示されます

def iterative(steps: float) -> str:
    yield "Starting"
    for i in range(int(steps)):
        time.sleep(1)
        yield str(i)
    yield "Done!"

demo = gr.Interface(iterative, inputs=gr.Number(0), outputs="text")
demo.queue().launch()

プログレスバー

キューを有効にした場合のみ、画面上にプログレスバーを表示できます。

関数の入力引数の後ろに、デフォルト値をgr.Progress()とした引数を追加します。
progress()で進捗を0.0~1.0で指定するか、または、ループでprogress.tqdm(range(steps))としてtqdmライクに記述できます。

def hello_progress(name: str, progress=gr.Progress()) -> str:
    progress(0.0, desc="Starting...")
    time.sleep(1)
    for _ in progress.tqdm(range(100)):
        time.sleep(0.1)
    return f"Hello, {name}!!"

hello_progress_interface = gr.Interface(
    fn=hello_progress, inputs="text", outputs="text"
)
hello_progress_interface.queue().launch()

Alt text

API

GradioのInterfaceやBlocksでは、ブラウザからアクセスするためのWebページ以外にも、デフォルトでWeb APIを提供します。
Blocksの最初の例を使います。

with gr.Blocks() as hello_blocks_interface:
    input = gr.Textbox(label="Name")
    output = gr.Textbox(label="Output Box")
    button = gr.Button("Hello")
    button.click(fn=hello, inputs=input, outputs=output, api_name="hello")

hello_blocks_interface.launch()

python app.py (または gradio app.py) で起動してページを表示すると、下の方にUse via APIというリンクが現れ、リンク先にアクセス方法が記載されています。
gradio_clientでリクエストでき、api_name="hello"で指定している場合はapi_name="/hello"で指定します。Interfaceの場合は、"/predict" となります。
.launch(show_api=False)とすると、Web APIは公開されません。

Alt text
Alt text

inputsもoutputsもtextなので、strでやり取りできます。

>>> from gradio_client import Client
>>> client = Client("http://127.0.0.1:7861/")
Loaded as API: http://127.0.0.1:7861/>>> result = client.predict("ohke", api_name="/hello")
>>> print(result)
Hello, ohke!!

リクエスト情報

関数の入力引数の後ろに、gr.Request型の変数を追加すると、リクエストヘッダーやクッキーなどのリクエスト情報にアクセスできます。

def hello_request(name: str, request: gr.Request):
    if request:
        print(f"{request.request.headers=}")
        print(f"{request.client.host=}")
    return "Hello, " + name + "!"

hello_request_interface = gr.Interface(hello_request, inputs="text", outputs="text")
hello_request_interface.launch()
# request.request.headers=Headers({'host': 'localhost:7861', 'connection': 'keep-alive', 'content-length': '77', 'sec-ch-ua': '"Not.A/Brand";v="8", "Chromium";v="114", "Brave";v="114"', 'sec-ch-ua-platform': '"macOS"', 'sec-ch-ua-mobile': '?0', 'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36', 'content-type': 'application/json', 'accept': '*/*', 'sec-gpc': '1', 'accept-language': 'ja', 'origin': 'http://localhost:7861', 'sec-fetch-site': 'same-origin', 'sec-fetch-mode': 'cors', 'sec-fetch-dest': 'empty', 'referer': 'http://localhost:7861/', 'accept-encoding': 'gzip, deflate, br', 'cookie': 'access-token=6zHJL_LAwPCH0J8vReBMZw; access-token-unsecure=6zHJL_LAwPCH0J8vReBMZw'})
# request.client.host='127.0.0.1'

まとめ

Gradioを試してみた所感を最後にまとめます。

まず入出力のUIコンポネントが一通り揃っていて、入力・推論・出力だけのシンプルなデモアプリケーションを作るのに困ることはなさそうです。多少複雑になっても、BlocksでTabやStateなどを組み合わせればなんとかなりそうです。
動画像もGradio側でNumPyやPillowの形式に変換してくれたり、イテレーティブな出力やストリーム処理もサポートしていたりするので、推論を実行する前処理・後処理のコードが膨れ上がる、といったことは回避しやすくなっています。
AIモデルの推論と可視化に特化しているだけあって、シンプルなデモアプリケーションであれば簡単に作れそうです。

また、キューの仕組みによって、同時処理数を制御できるのも嬉しい点です。
処理中のリクエスト数や進捗状態 (プログレスバー) がUIに表示しやすくなっていて、ユーザにとっても親切です。

その反面、複数ページをまとめたり(Streamlitで言うmultipage apps)、ダッシュボードを作ったりなど、綺麗に可視化するという点においては弱いという印象です。
また、1つのデモ(画面)あたり1つのWebアプリケーションサーバが立つことになるので、デプロイや運用面ではタスクや考慮ポイントが増えます。

総じて、AIモデルのデモや動作確認などのユースケースであれば、Gradioはデフォルトの選択肢として良さそうです。

参考

Gradio - Quickstart
Gradio - Docs
【Streamlitよりいいかも?】機械学習系のデモアプリ作成に最適!Gradio解説 - 学習する天然ニューラルネット

Discussion