🚀

Streamlitで機械学習モデルをキャッシュして、推論表示を早くする話

2024/01/24に公開

背景

機械学習モデルを使って、予測を実行するアプリケーションをstreamlitで作っていたのですが、予測実行のたびにモデル読み込みで時間がかかっていたので、これを省略するためにchatGPTさんに解決方法を聴きました。

やったこと

before

以下がアプリコードです。time.sleep(10)が実際には機械学習モデルのpickle.loadになるわけですが、これだと推論ボタンを押すたびにモデルをloadして走ることになります。

import streamlit as st
import pickle
import time

def load_model():
    # st.spinnerを使用してローディングアニメーションを表示
    with st.spinner('10秒間待機しています...'):
        # 10秒間待機
        time.sleep(10)
    model = 'modelだよ'
    return model

# Streamlitアプリのメイン関数
def main():

    # ボタンを表示
    if st.button('ローディングを開始'):
        # モデルをロード
        model = load_model()
        st.write(model)

# アプリを実行
if __name__ == "__main__":
    main()
折りたたむ

after(モデルインスタンスをキャッシュする)

stremalitは常にコードを実行し続けている理解なので、pickle.loadしたものをインスタンスにして持つのはイメージが無かったのですが、st.cacheを使うことで可能なようです。

import streamlit as st
import pickle
import time

# モデルをロードする関数に@st.cacheデコレータを適用
@st.cache(allow_output_mutation=True)
def load_model():
    # st.spinnerを使用してローディングアニメーションを表示
    with st.spinner('10秒間待機しています...'):
        # 10秒間待機
        time.sleep(10)
    model = 'modelだよ'
    return model

# Streamlitアプリのメイン関数
def main():
    # モデルをロード
    model = load_model()

    # ボタンを表示
    if st.button('ローディングを開始'):
        st.write(model)

# アプリを実行
if __name__ == "__main__":
    main()

本当はアプリケーションを動かしている動画を張り付けたかったのですが、初心者に付き保留します。後日更新予定です。

Discussion