gradioでつくる簡易ダッシュボード
はじめに
みなさん。Splatoon3遊んでいますか。私は昨年9月の発売以来、かなり熱心に遊んでいます[1]。単に遊ぶだけでなく、なるべく勝てるように、つまり上達できるように努力しています。
最近は試合の結果を記録して分析し、上達に役立てたいと考えるようになりました。スプラトゥーンというゲームをご存じない方のために簡単に説明すると、スプラトゥーンにはナワバリバトルを含め全5種のルールがあり、ステージは2023年12月時点で全22種類存在します。
それぞれのステージには、防衛の要となる箇所や攻めの起点を作るために抑えたいキーポイントがあり、勝率を上げるためにはこうしたステージごとの戦略の研究は欠かせません。しかし全ステージについて研究するのは大変なので、なるべく勝率の悪いステージから対策したいところです。
任天堂が提供する公式アプリ「イカリング3」でも、ステージごとの勝率は確認できますが一覧性に欠けるうえ、ゲームを開始以来すべての戦績が反映されているため、求める情報からずれることがあります。また、「splat.ink」という非公式のアプリケーションもあり、こちらは自動で戦績が記録される他、詳細な集計もできるもので便利ではあるのですが、非公開APIを利用しているため、利用はややためらわれます[2]。
そこで、戦績を収集し可視化する分析環境をGoogleフォーム+スプレッドシート+colabで作成し、しばらく運用していました。そのままでも良かったのですが、colabが若干使いづらかったので、自分でWebアプリケーションを作成することにしました。
作るもの
次のようなアプリケーションを作成します
- 戦績(ステージ、ルール、モード、結果)を手動で入力する
- ステージ別やルール別の勝率を集計し、可視化する
画像データを機械学習モデルで分類すれば、非公開APIを使わずとも戦績の記録はできるし、実際にSplatoonのゲーム画面から情報を取得するための機械学習モデルを開発されている方もいるようです。しかし、キャプチャーボードが必要になりますし、今回は手動で入力することにしました。PCで入力画面を開いておいて、試合が終わるたびにステージや結果を選びます。
集めたデータは可視化して分析し、練習計画に役立てたりしたいところです。
gradio + fastapi
バックエンドにはfastapiを使います。この選択はこだわりがあるわけではありません[3]。フロントはgradioを使います。gradioはHugging Faceが中心となって開発しているOSSのWebUIフレームワークです。Pandasのデータフレームを表形式で表示できたり、matplotlibのグラフをそのまま表示できたりと、Pythonのデータ分析系のライブラリと非常に相性が良いです。
gradioを利用する理由は大体次の通りです
- グラフを表示したいから
- 上述のcolabで書いた関数を再利用できるから
- 興味があったから
実装
バックエンド
FastAPIで作成します。バックエンドについて特筆すべきことはありません。FastAPIを利用した開発に関心がある人はFastAPI入門を読むと良いでしょう(今回私も参考にしました)。
モデルはこんな感じです
from datetime import datetime
from zoneinfo import ZoneInfo
from sqlalchemy import Column, Integer, String, DateTime
from src.db import Base
class BattleLog(Base):
__tablename__ = "battlelog"
id = Column(Integer, primary_key=True, index=True)
mode = Column(String(255))
rule = Column(String(255))
stage = Column(String(255))
result = Column(String(255))
created_at = Column(DateTime, default=datetime.now(ZoneInfo("Asia/Tokyo")))
エンドポイントは以下の通りです
from typing import List
from fastapi import APIRouter, Depends
from fastapi.exceptions import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
import src.schemas.battlelog as bs
import src.cruds.battlelog as bc
from src.db import get_db
router = APIRouter()
@router.get("/api/v1/battlelog", response_model=List[bs.BattleLog])
async def get_battlelog(db: AsyncSession = Depends(get_db)):
return await bc.get_all_battlelog(db)
@router.post("/api/v1/battlelog", response_model=bs.BattleLogCreateResponse)
async def create_battlelog(
body: bs.BattleLogCreate, db: AsyncSession = Depends(get_db)
):
if body.mode == "レギュラーマッチ" and body.rule != "ナワバリバトル":
raise HTTPException(status_code=400, detail="invalid rule")
if body.rule == "ナワバリバトル" and body.mode != "レギュラーマッチ":
raise HTTPException(status_code=400, detail="invalid mode")
return await bc.create_battlelog(db, body)
@router.delete("/api/v1/battlelog/{id}", response_model=None)
async def delete_battlelog(id: int, db: AsyncSession = Depends(get_db)):
log = await bc.get_battlelog(db, id)
if log is None:
raise HTTPException(status_code=404, detail="not found")
return await bc.delete_battlelog(db, log)
最低限データの追加、全件取得ができればよいです。今の所使う予定はありませんが、削除用のAPIも用意しています。
またスプレッドシートに溜まっているデータをインポートできるように、簡単なスクリプトを作成しました。
import pandas as pd
import requests as r
API_HOST = "http://localhost:8000/api/v1"
def main():
df = pd.read_csv("data/log.csv")
for _, row in df.iterrows():
mode = row['モード']
rule = row['ルール']
stage = row['ステージ']
result = row['結果']
r.post(f"{API_HOST}/battlelog", json={
"mode": mode,
"rule": rule,
"stage": stage,
"result": result
})
if __name__ == "__main__":
main()
フロントエンド
gradioで作ります。
データ登録画面
登録画面ではステージ、ルール、モード[4]、勝敗が記録できるようにします。ステージはプルダウン、他はラジオボタンで値を選びます。
このようにblockを定義しすると
with gr.Blocks() as blocks:
with gr.Row():
gr.Button("ダッシュボード", link="/dashboard")
mode = gr.Radio(
label="モード",
choices=[m.value for m in Mode],
)
rule = gr.Radio(
label="ルール",
choices=[m.value for m in Rule],
)
stage = gr.Dropdown(
label="ステージ",
choices=[m.value for m in Stage],
)
result = gr.Radio(
label="結果",
choices=[m.value for m in Result],
)
submit_btn = gr.Button("登録", variant="primary")
次のような画面が生成されます
ボタンが押下された時の処理は次のように定義します。
def add_record(mode, rule, stage, result):
res = r.post(f"{API_HOST}/battlelog", json={
"mode": mode,
"rule": rule,
"stage": stage,
"result": result
})
if res.status_code != 200:
raise gr.Error("登録できませんでした")
logs = r.get(f"{API_HOST}/battlelog").json()
latest = pd.DataFrame(logs, columns=columns)
return latest
submit_btn.click(
add_record,
inputs=[mode, rule, stage, result],
outputs=[latest],
).success(None, _js="window.alert('登録しました')")
add_record
の中で先程作成したデータ生成用のAPIを叩いています。API側でバリデーションしているので、おかしなデータが飛んできたときは200以外のresponse codeを返します。この時はErrorを発生させます。そうでないときはsuccessの中に書いたjsが実行され「登録しました」というポップアップがでます(操作に対して何かしらのフィードバックが無いと登録されたかどうかわかりにくいため)
本当なら raise gr.Error
ではなく gr.Info
と gr.Warning
を使いたかったのですが、どういうわけかポップアップが出なかったので、代替案としてこのように実装しています。
ダッシュボード
gradioにはmatplotlibで生成したグラフをそのまま表示するためのPlotというコンポーネントがあります。これを利用してダッシュボードを作成します。
行が多いので割愛しますが、愚直に必要な要素を定義していきます
with gr.Blocks() as blocks:
gr.Button("top", link="/top")
submit_btn = gr.Button('集計', variant='primary')
with gr.Row():
winrate_by_rule = gr.Plot(label='ルール別勝率')
winrate_by_mode = gr.Plot(label='モード別勝率')
header_stage = ['ステージ', 'wins', 'loses', 'win_rate', 'ci.low', 'ci.high', 'mean']
header_rule = ['ルール', 'wins', 'loses', 'win_rate', 'ci.low', 'ci.high', 'mean']
with gr.Row():
winrate_by_stage = gr.Plot(label='ステージ別勝率')
winrate_by_stage_df = gr.DataFrame(
label='ステージ別勝率',
headers=header_stage,
)
with gr.Row():
....
勝率を計算するメソッドは次のとおりです。
def aggregate_result(group):
total = group.size
win = group[group == "勝ち"].size
lose = group[group == "負け"].size
ci = binomtest(win, total, 0.5).proportion_ci(confidence_level=0.95)
win_rate = win / total
return pd.Series({
'wins': win,
'loses': lose,
'win_rate': win_rate,
'ci.low': ci.low,
'ci.high': ci.high,
'mean': (ci.low + ci.high) / 2,
})
ステージやルール別に集計すると、対戦回数にはかなりばらつきが生じるので、ある程度の目安とするために、二項検定[5]で信頼区間を算出しています。
この集計データをグラフにするのが、以下のメソッドです。
def win_rate_ci(df):
df = df[::-1]
lower_err = df['mean'] - df['ci.low']
upper_err = df['ci.high'] - df['mean']
fig, ax = plt.subplots()
ax.errorbar(x=df['mean'], y=df.index, xerr=[lower_err, upper_err], fmt='o', linewidth=2, capsize=6)
ax.set(xlim=(0.0, 1.0))
ax.grid()
plt.tight_layout()
return fig
gr.Plot
というコンポーネントにmatplotlibのfigureを渡すと、グラフが描画されます。
gr.Plot(win_rate_ci(stage), label='ステージ別勝率'),
こんな感じ[6]
結構はっきりと傾向がでていますね。「タラポートショッピングパーク」は試合数が少ないので除外して考えるとして、「海女美術大学」「キンメダイ美術館」「マンタマリア号」の勝率が悪いようです。遊んでいて苦しいと感じるのは「ナメロウ金属」のようなステージですが、意外と勝率が悪くないというのも興味深いです。
他にもルール別やモード別、ルール・ステージ別の集計なども作成しました。
※縦に長いので下部は割愛
まとめ
gradioは簡易なUIを作成する用途であればかなり使いやすいと感じました。データ分析や機械学習系のコードとの連携が要件にあるなら有力な選択肢になるのではないでしょうか。
データの方は早速対策の方針が立てられそうな結果が得られたので満足しています。今後は軸を追加したりしてきたいですね。
コードはここにおいてあります。動作の保証は出来ませんが、読んで参考にすることはできると思います。
-
My Nintendoで確認したところ1300時間くらい遊んでました。 ↩︎
-
stat.inkを使っていたせいでアカウントが利用停止になった!という話は今の所聞いたことはありませんが。 ↩︎
-
Pythonでかつ軽量のフレームワークなら何でも良かった。 ↩︎
-
モードというのはレギュラーマッチ、バンカラマッチ(チャレンジ/オープン)、Xマッチの区別のことです。スプラトゥーンを知らない人は何のことかわからないと思いますが、まぁそういうのがあるんだなと思っていただければよいです。 ↩︎
-
二項検定を利用する理由は標本数が少ないからです。3ヶ月の間毎日休まずスプラトゥーンを遊んだとしてもステージ別に集計すると試合数は高々数十試合にしかなりません。 ↩︎
-
11月までのデータを使っているので、「バイガイ亭」と「ネギトロ炭鉱」は入っていません。 ↩︎
Discussion