Streamlit で threading や multiprocessing を使った非同期処理を行う
Streamlit は、お手軽に簡単なウェブアプリを構築できるPythonのフレームワークです。 とても便利で気に入っているので、昨年のAdventCalendarの記事で、使い方をいくつか紹介しました。その中で、
Python で非同期処理を行う場合は、threading や multiprocessing を使って、スクリプト内で別スレッド/プロセスを立ち上げるのが普通ですが、Streamlitでは、その仕組み上そういったことができません。
ということを書いたのですが、「あれ?久々にやってみたら普通にできたぞ。」となったので、サンプルコードを紹介したいと思います。
threading を使った非同期処理の復習
Pythonでは、 threading
パッケージを使うことで、マルチスレッド処理を実現できます。
例えば1秒ごとにカウンターを1だけ増やす処理は、以下のように記述できます。
import threading
import time
class Worker(threading.Thread):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.counter = 0
def run(self):
while True:
time.sleep(1)
self.counter += 1
print(self.counter)
# daemon=Trueとすると、メインスレッド終了時に子スレッドも終了する
worker = Worker(daemon=True)
worker.start()
time.sleep(5)
上記のコードを実行すると、 1秒おきに数字が表示され、終了します。threading.Thread
を継承したクラスを作成し、 させたい処理を run
メソッドに記述するだけなので、とても簡単です。
ここでは、親スレッドの終了によって子スレッドが終了していますが、途中で子スレッドを終了させることもできます。それにはいくつか方法がありますが、threading.Event
を使うと簡単です。例えば、以下のコードを実行すると、1から5まで、1秒おきに数字が表示された後、5秒待ってからプログラムが終了します。
import threading
import time
class Worker(threading.Thread):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.counter = 0
self.should_stop = threading.Event()
def run(self):
while not self.should_stop.wait(0):
time.sleep(1)
self.counter += 1
print(self.counter)
worker = Worker(daemon=True)
worker.start()
time.sleep(5)
# worker.should_stop を発火させる
worker.should_stop.set()
time.sleep(5)
print('done.')
Streamlit と threading を組み合わせる
Streamlitのボタンを使って、この threading.Event
を発火させることで、Streamlitを使ったスレッドの制御ができます。
ポイントは、
- Streamlitでは、ボタン押下等のイベント毎に
rerun
が走る(コードが再実行される)。session_state に worker を保存することで、rerun
が発生しても同一のworker
を見続けられる。 -
worker
を制御(起動/停止)するたびにst.experimental_rerun()
を呼び出すことで、workerの制御と表示が簡単に分離できる。
import threading
import time
import streamlit as st
class Worker(threading.Thread):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.counter = 0
self.should_stop = threading.Event()
def run(self):
while not self.should_stop.wait(0):
time.sleep(1)
self.counter += 1
def main():
if 'worker' not in st.session_state:
st.session_state.worker = None
worker = st.session_state.worker
# worker を制御(起動/停止)する部分
with st.sidebar:
if st.button('Start worker', disabled=worker is not None):
# daemon=True とすることで、Ctrl+C で終了できるようにする
worker = st.session_state.worker = Worker(daemon=True)
worker.start()
st.experimental_rerun()
if st.button('Stop worker', disabled=worker is None):
worker.should_stop.set()
# 終了まで待つ
worker.join()
worker = st.session_state.worker = None
st.experimental_rerun()
# worker の状態を表示する部分
if worker is None:
st.markdown('No worker running.')
else:
st.markdown(f'worker: {worker.getName()}')
placeholder = st.empty()
while worker.is_alive():
placeholder.markdown(f'counter: {worker.counter}')
time.sleep(1)
if __name__ == '__main__':
main()
multiprocessing を利用する
threading
とほぼ同じコードでマルチプロセスにすることができます。大きな違いは、変数のプロセス間共有のために multiprocessing.Value
を使っている点くらいでしょうか。
import multiprocessing
import time
import streamlit as st
class Worker(multiprocessing.Process):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.should_stop = multiprocessing.Event()
self.counter = multiprocessing.Value('i', 0)
def run(self):
while not self.should_stop.wait(0):
time.sleep(1)
self.counter.value += 1
def main():
if 'worker' not in st.session_state:
st.session_state.worker = None
worker: Worker = st.session_state.worker
with st.sidebar:
if st.button('Start worker', disabled=worker is not None):
worker = st.session_state.worker = Worker(daemon=True)
worker.start()
st.experimental_rerun()
if st.button('Stop worker', disabled=worker is None):
worker.should_stop.set()
worker.join()
worker = st.session_state.worker = None
st.experimental_rerun()
if worker is None:
st.markdown('No worker running.')
else:
st.markdown(f'worker: {worker.pid}')
placeholder = st.empty()
while worker.is_alive():
placeholder.markdown(f'counter: {worker.counter.value}')
time.sleep(1)
if __name__ == '__main__':
main()
複数セッションで worker を共有する
さて、 st.session_state
を使うことで、 マルチスレッド/マルチプロセスな処理ができるようになりましたが、上記のコードには、実は問題があります。
st.session_state
はあくまで同一セッション内で状態を保持するためのものなので、例えばブラウザをリロードしたりすると、st.session_state
も新しいものになってしまいます。
すると、リロードする前に起動した worker を停止したり監視したりできなくなってしまいます。
では、どうすればよいかというと、 st.experimental_singleton
を使ってシングルトンオブジェクトを作り、そこにworkerの情報をまとめます。すると、以下の画像の様に、リロードしたり複数のブラウザで立ち上げても、同一のworkerを制御できるようになります(上下2つのブラウザどちらで操作しても、workerの起動/停止ができているのがわかるかと思います)。
具体的なコードは以下のとおりです。ThreadManager
がシングルトンとなって、workerの制御をしてくれています。
コード全体
import dataclasses
import threading
import time
import streamlit as st
class Worker(threading.Thread):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.counter = 0
self.should_stop = threading.Event()
def run(self):
while not self.should_stop.wait(0):
time.sleep(1)
self.counter += 1
@st.experimental_singleton
@dataclasses.dataclass
class ThreadManager:
worker = None
def get_worker(self):
return self.worker
def is_running(self):
return self.worker is not None and self.worker.is_alive()
def start_worker(self):
if self.worker is not None:
self.stop_worker()
self.worker = Worker(daemon=True)
self.worker.start()
return self.worker
def stop_worker(self):
self.worker.should_stop.set()
self.worker.join()
self.worker = None
def main():
thread_manager = ThreadManager()
with st.sidebar:
if st.button('Start worker', disabled=thread_manager.is_running()):
worker = thread_manager.start_worker()
st.experimental_rerun()
if st.button('Stop worker', disabled=not thread_manager.is_running()):
thread_manager.stop_worker()
st.experimental_rerun()
if not thread_manager.is_running():
st.markdown('No worker running.')
else:
worker = thread_manager.get_worker()
st.markdown(f'worker: {worker.getName()}')
placeholder = st.empty()
while worker.is_alive():
placeholder.markdown(f'counter: {worker.counter}')
time.sleep(1)
# 別セッションでの更新に追従するために、定期的にrerunする
time.sleep(1)
st.experimental_rerun()
if __name__ == '__main__':
main()
まとめ
Streamlitと threading
や multiprocessing
の組み合わせによる非同期処理の方法を紹介しました。st.experimental_singleton
や st.experimental_rerun()
を使うことで、非同期処理もきれいにかけそうです。
Discussion