👏
Pythonのmultiprocessingで子プロセスの異常終了を検知したい時の対処方法
問題
multiprocessingライブラリで並列化処理を実施している際、子プロセスが異常停止する(OOM KillerなどにKillされるなど)と、親プロセスがハングアップしてしまう問題があります。
これはCPythonのIssueでも報告されており、対処方法が日本語記事にはあまり見当たらなかったので備忘録として記載しておきます。
再現コードとして、下記を用意しました。実行してみると、親プロセスが子プロセスの異常終了を検知できず応答なしの状態になることが確認できます。
import multiprocessing
import sys
def worker_fn(dummy: int):
print(dummy)
if dummy == 3:
sys.exit(1)
return dummy
def do_parallel(inputs: list[int]):
with multiprocessing.Pool(processes=3) as pool:
results = pool.map(worker_fn, inputs)
return results
if __name__ == "__main__":
do_parallel([1, 2, 3])
対処方法
前述したCPythonのIssueでも案内されている通り、multiprocessingの代わりにconcurrent.futuresを用います。
concurrent.futuresはPythonの非同期処理を扱うための高レベルAPIです。詳細は以下の公式リファレンスを参照してください。
実装例を以下に示します。
import concurrent.futures as cf
import sys
def worker_fn(dummy: int):
print(dummy)
if dummy == 3:
sys.exit(1)
return dummy
def do_parallel_2(inputs: list[int]):
with cf.ProcessPoolExecutor(max_workers=3) as executor:
futures: list[cf.Future] = [
executor.submit(worker_fn, v) for v in inputs
]
# Wait until all execution are completed.
cf.wait(futures)
for future in futures:
# Detect failure
if future.exception() is not None:
raise Exception("Subprocess is failed.")
# flatten
return [f.result() for f in futures]
if __name__ == "__main__":
# the code below raises Exception
# do_parallel_2([1, 2, 3])
results = do_parallel_2([1, 2, 4])
# stdout: [1, 2, 4]
print(results)
以降に実装のポイントを簡単に示しておきます。
- ProcessPoolExecutorインスタンスを生成する(executor)
- executorに実行する関数と引数をペアで渡す(submitする)ことで、Futureオブジェクトを作る
- Futureオブジェクトは非同期に実行された関数の結果を格納するクラス(JavaScriptのPromiseと扱いは似ています)
- 非同期な関数がすべて終了状態になるのを待つ(cf.wait)
- Futureオブジェクトの中身をみることで、失敗の有無や戻り値を確認できる。
- Future.exception(): 実行時に発生した例外を返す
- Future.result():実行した関数の戻り値を返す
補足
上記では、チャンクサイズなどを考慮していませんが、worker_fnをうまく変更することで実現できます。
Discussion