🍺

Streamlit で threading や multiprocessing を使った非同期処理を行う

2022/03/09に公開

Streamlit は、お手軽に簡単なウェブアプリを構築できるPythonのフレームワークです。 とても便利で気に入っているので、昨年のAdventCalendarの記事で、使い方をいくつか紹介しました。その中で、

Python で非同期処理を行う場合は、threading や multiprocessing を使って、スクリプト内で別スレッド/プロセスを立ち上げるのが普通ですが、Streamlitでは、その仕組み上そういったことができません。

ということを書いたのですが、「あれ?久々にやってみたら普通にできたぞ。」となったので、サンプルコードを紹介したいと思います。

threading を使った非同期処理の復習

Pythonでは、 threading パッケージを使うことで、マルチスレッド処理を実現できます。
例えば1秒ごとにカウンターを1だけ増やす処理は、以下のように記述できます。

import threading
import time


class Worker(threading.Thread):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.counter = 0
        
    def run(self):
        while True:
            time.sleep(1)
            self.counter += 1
            print(self.counter)

# daemon=Trueとすると、メインスレッド終了時に子スレッドも終了する
worker = Worker(daemon=True)
worker.start()
time.sleep(5)

上記のコードを実行すると、 1秒おきに数字が表示され、終了します。threading.Thread を継承したクラスを作成し、 させたい処理を run メソッドに記述するだけなので、とても簡単です。

ここでは、親スレッドの終了によって子スレッドが終了していますが、途中で子スレッドを終了させることもできます。それにはいくつか方法がありますが、threading.Eventを使うと簡単です。例えば、以下のコードを実行すると、1から5まで、1秒おきに数字が表示された後、5秒待ってからプログラムが終了します。

import threading
import time


class Worker(threading.Thread):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.counter = 0
        self.should_stop = threading.Event()
        
    def run(self):
        while not self.should_stop.wait(0):
            time.sleep(1)
            self.counter += 1
            print(self.counter)

worker = Worker(daemon=True)
worker.start()
time.sleep(5)
# worker.should_stop を発火させる
worker.should_stop.set()
time.sleep(5)
print('done.')

Streamlit と threading を組み合わせる

Streamlitのボタンを使って、この threading.Event を発火させることで、Streamlitを使ったスレッドの制御ができます。
ポイントは、

  1. Streamlitでは、ボタン押下等のイベント毎に rerun が走る(コードが再実行される)。session_state に worker を保存することで、 rerun が発生しても同一の worker を見続けられる。
  2. worker を制御(起動/停止)するたびに st.experimental_rerun() を呼び出すことで、workerの制御と表示が簡単に分離できる。
import threading
import time

import streamlit as st


class Worker(threading.Thread):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.counter = 0
        self.should_stop = threading.Event()
        
    def run(self):
        while not self.should_stop.wait(0):
            time.sleep(1)
            self.counter += 1

def main():
    if 'worker' not in st.session_state:
        st.session_state.worker = None
    worker = st.session_state.worker

    # worker を制御(起動/停止)する部分
    with st.sidebar:
        if st.button('Start worker', disabled=worker is not None):
            # daemon=True とすることで、Ctrl+C で終了できるようにする
            worker = st.session_state.worker = Worker(daemon=True)
            worker.start()
            st.experimental_rerun()
            
        if st.button('Stop worker', disabled=worker is None):
            worker.should_stop.set()
            # 終了まで待つ
            worker.join()
            worker = st.session_state.worker = None
            st.experimental_rerun()
    
    # worker の状態を表示する部分
    if worker is None:
        st.markdown('No worker running.')
    else:
        st.markdown(f'worker: {worker.getName()}')
        placeholder = st.empty()
        while worker.is_alive():
            placeholder.markdown(f'counter: {worker.counter}')
            time.sleep(1)


if __name__ == '__main__':
    main()

multiprocessing を利用する

threading とほぼ同じコードでマルチプロセスにすることができます。大きな違いは、変数のプロセス間共有のために multiprocessing.Value を使っている点くらいでしょうか。

import multiprocessing
import time

import streamlit as st


class Worker(multiprocessing.Process):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.should_stop = multiprocessing.Event()
        self.counter = multiprocessing.Value('i', 0)
        
    def run(self):
        while not self.should_stop.wait(0):
            time.sleep(1)
            self.counter.value += 1

def main():
    if 'worker' not in st.session_state:
        st.session_state.worker = None
    worker: Worker = st.session_state.worker

    with st.sidebar:
        if st.button('Start worker', disabled=worker is not None):
            worker = st.session_state.worker = Worker(daemon=True)
            worker.start()
            st.experimental_rerun()
            
        if st.button('Stop worker', disabled=worker is None):
            worker.should_stop.set()
            worker.join()
            worker = st.session_state.worker = None
            st.experimental_rerun()


    if worker is None:
        st.markdown('No worker running.')
    else:
        st.markdown(f'worker: {worker.pid}')
        placeholder = st.empty()
        while worker.is_alive():
            placeholder.markdown(f'counter: {worker.counter.value}')
            time.sleep(1)


if __name__ == '__main__':
    main()

複数セッションで worker を共有する

さて、 st.session_state を使うことで、 マルチスレッド/マルチプロセスな処理ができるようになりましたが、上記のコードには、実は問題があります。
st.session_state はあくまで同一セッション内で状態を保持するためのものなので、例えばブラウザをリロードしたりすると、st.session_state も新しいものになってしまいます。
すると、リロードする前に起動した worker を停止したり監視したりできなくなってしまいます。

では、どうすればよいかというと、 st.experimental_singleton を使ってシングルトンオブジェクトを作り、そこにworkerの情報をまとめます。すると、以下の画像の様に、リロードしたり複数のブラウザで立ち上げても、同一のworkerを制御できるようになります(上下2つのブラウザどちらで操作しても、workerの起動/停止ができているのがわかるかと思います)。

具体的なコードは以下のとおりです。ThreadManager がシングルトンとなって、workerの制御をしてくれています。

コード全体
import dataclasses
import threading
import time

import streamlit as st


class Worker(threading.Thread):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.counter = 0
        self.should_stop = threading.Event()
        
    def run(self):
        while not self.should_stop.wait(0):
            time.sleep(1)
            self.counter += 1


@st.experimental_singleton
@dataclasses.dataclass
class ThreadManager:
    worker = None
    
    def get_worker(self):
        return self.worker
    
    def is_running(self):
        return self.worker is not None and self.worker.is_alive()
    
    def start_worker(self):
        if self.worker is not None:
            self.stop_worker()
        self.worker = Worker(daemon=True)
        self.worker.start()
        return self.worker
    
    def stop_worker(self):
        self.worker.should_stop.set()
        self.worker.join()
        self.worker = None



def main():
    thread_manager = ThreadManager()

    with st.sidebar:
        if st.button('Start worker', disabled=thread_manager.is_running()):
            worker = thread_manager.start_worker()
            st.experimental_rerun()
            
        if st.button('Stop worker', disabled=not thread_manager.is_running()):
            thread_manager.stop_worker()
            st.experimental_rerun()
    
    if not thread_manager.is_running():
        st.markdown('No worker running.')
    else:
        worker = thread_manager.get_worker()
        st.markdown(f'worker: {worker.getName()}')
        placeholder = st.empty()
        while worker.is_alive():
            placeholder.markdown(f'counter: {worker.counter}')
            time.sleep(1)

    # 別セッションでの更新に追従するために、定期的にrerunする
    time.sleep(1)
    st.experimental_rerun()


if __name__ == '__main__':
    main()

まとめ

Streamlitと threadingmultiprocessing の組み合わせによる非同期処理の方法を紹介しました。st.experimental_singletonst.experimental_rerun() を使うことで、非同期処理もきれいにかけそうです。

Discussion