🚀
Streamlitで機械学習モデルをキャッシュして、推論表示を早くする話
背景
機械学習モデルを使って、予測を実行するアプリケーションをstreamlitで作っていたのですが、予測実行のたびにモデル読み込みで時間がかかっていたので、これを省略するためにchatGPTさんに解決方法を聴きました。
やったこと
before
以下がアプリコードです。time.sleep(10)が実際には機械学習モデルのpickle.loadになるわけですが、これだと推論ボタンを押すたびにモデルをloadして走ることになります。
import streamlit as st
import pickle
import time
def load_model():
# st.spinnerを使用してローディングアニメーションを表示
with st.spinner('10秒間待機しています...'):
# 10秒間待機
time.sleep(10)
model = 'modelだよ'
return model
# Streamlitアプリのメイン関数
def main():
# ボタンを表示
if st.button('ローディングを開始'):
# モデルをロード
model = load_model()
st.write(model)
# アプリを実行
if __name__ == "__main__":
main()
折りたたむ
after(モデルインスタンスをキャッシュする)
stremalitは常にコードを実行し続けている理解なので、pickle.loadしたものをインスタンスにして持つのはイメージが無かったのですが、st.cacheを使うことで可能なようです。
import streamlit as st
import pickle
import time
# モデルをロードする関数に@st.cacheデコレータを適用
@st.cache(allow_output_mutation=True)
def load_model():
# st.spinnerを使用してローディングアニメーションを表示
with st.spinner('10秒間待機しています...'):
# 10秒間待機
time.sleep(10)
model = 'modelだよ'
return model
# Streamlitアプリのメイン関数
def main():
# モデルをロード
model = load_model()
# ボタンを表示
if st.button('ローディングを開始'):
st.write(model)
# アプリを実行
if __name__ == "__main__":
main()
本当はアプリケーションを動かしている動画を張り付けたかったのですが、初心者に付き保留します。後日更新予定です。
Discussion