😸

Pythonで並列処理のすすめ

2022/05/05に公開

Python(CPython) の並列処理で適切にパフォーマンス改善する方法を解説します。

一般に並列処理で思い付くのはスレッド処理ですが、Python には GIL(グローバル・インタプリタ・ロック)と呼ばれるロック機構があるため、その方法ではパフォーマンス改善を望めないことがあります。
なぜならば Python コードを実行できるのは GIL を保持したスレッドだけなので複数スレッドで並列処理をしたつもりでも実際に実行されるのは1スレッドだけで、実質的には並行処理となってしまうからです。

当記事はスレッドベースの並行・並列処理の問題と、それに代わるプロセスベースの並列処理のメリットデメリットや大量のデータを扱う場合のパフォーマンス特性を解説します。
検証には Python 3.9.12 / Ubuntu 20.04.4 LTS を利用しています。

GIL の影響を見てみる

以下のような素朴なアルゴリズムで、指定した数値までに含まれる素数の数をカウントする関数を用意しました。
並行・並列処理から実行結果を受け取るためには return ではなく Queue を利用します。

from queue import Queue

def count_primes(num: int, queue: Queue) -> None:
    primes = 0
    for i in range(2, num + 1):
        for j in range(2, i):
            if i % j == 0:
                break
        else:
            primes += 1
    queue.put(primes)

まずは以下のようにスレッドを1つ生成して実行します。

from threading import Thread

queue1 = Queue()
thread1 = Thread(target=count_primes, args=(100000, queue1))
thread1.start()
thread1.join()
print(queue1.get())

100000 までに含まれる素数は 9592個あり、その計算には 30秒かかりました。

$ time python count_primes_thread.py 
9592

real    0m30.982s
user    0m30.965s
sys     0m0.010s

次は2つのスレッドで同時に実行します。

queue1 = Queue()
thread1 = Thread(target=count_primes, args=(100000, queue1))
thread1.start()

queue2 = Queue()
thread2 = Thread(target=count_primes, args=(100000, queue2))
thread2.start()

thread1.join()
print(queue1.get())
thread2.join()
print(queue2.get())

結果はほぼ2倍で約1分間かかりました。

$ time python count_primes_thread.py 
9592
9592

real    1m1.449s
user    1m1.665s
sys     0m0.730s

実行環境は複数コアのCPUなので並列処理できるはずですが、GIL の影響で同時実行が制限されて並行処理になっています。

ではプロセスベースの並列処理でパフォーマンスを改善する方法を考えます。

プロセスベースの並列処理

Python の GIL はインタプリタ単位で作用するので fork して生成された子プロセスは親プロセスの GIL の影響を受けずに並列処理できます。
そこで利用するのが multiprocessing モジュールで、スレッドとよく似たインターフェースでプロセスを扱えます。

ちなみに Python には子プロセスの生成方法がいくつか用意されていますが、当記事では Linux 環境のデフォルトである fork を前提としています。

先ほどのスレッド(threading)を利用したプログラムを multiprocessing モジュールで書き換えてみます。
とは言っても、両者はほぼ共通のインターフェースを持つため、この場合は import するクラスを置き換えるだけの修正で動きます。

from multiprocessing import Process, Queue

def count_primes(num: int, queue: Queue) -> None:
    primes = 0
    for i in range(2, num + 1):
        for j in range(2, i):
            if i % j == 0:
                break
        else:
            primes += 1
    queue.put(primes)

queue1 = Queue()
process1 = Process(target=count_primes, args=(100000, queue1))
process1.start()

queue2 = Queue()
process2 = Process(target=count_primes, args=(100000, queue2))
process2.start()

process1.join()
print(queue1.get())
process2.join()
print(queue2.get())

実行すると先ほど1分間以上かかった処理が34秒で終了し、約2倍に高速化しました。

$ time python count_primes_process.py 
9592
9592

real    0m34.479s
user    1m8.574s
sys     0m0.010s

プロセスベースの並列処理なら Python コードも並列実行できることが分かります。

ただしスレッドと似たインターフェースで扱いやすいのは良いのですが、実際にはプロセスとスレッドは本質的に異なるものであり不適切な利用をすると望まぬ結果になります。そこで次はプロセスベースの並列処理でハマりやすいポイントを解説します。

スレッドはメモリ空間を共有、プロセスは独立

スレッドは親プロセスとメモリ空間を共有するので、例えばグローバル変数の値を変更すればスレッドの呼び出し元もその変更結果を取得できます。
一方で子プロセスは fork のタイミングで親プロセスからコピーされたメモリ空間を持ち、親プロセスの持っていた変数を参照することができますが、子プロセス側で変数を変更しても子プロセスが終了すると破棄されて親プロセス側はその値を取得できません。

実際に試してみます。

from multiprocessing import Process
from threading import Thread

global_value = 0

def worker() -> None:
    global global_value
    global_value += 1
    print(f"in worker               : {global_value=}")

# スレッドを生成・実行開始して終了まで待つ
print(f"before thread execution : {global_value=}")
thread = Thread(target=worker)
thread.start()
thread.join()
print(f"after thread execution  : {global_value=}")

# プロセスを生成・実行開始して終了まで待つ
print(f"before process execution: {global_value=}")
process = Process(target=worker)
process.start()
process.join()
print(f"after process execution : {global_value=}")

実行結果は以下のようになりました。

$ time python thread_process.py 
before thread execution : global_value=0
in worker               : global_value=1
after thread execution  : global_value=1
before process execution: global_value=1
in worker               : global_value=2
after process execution : global_value=1

スレッドは global_value1 加算した結果を呼び出し側でも受け取れています。
一方のプロセスは、worker 関数内の出力で global_value=2 となっているのでグローバル変数の値を参照できています。しかし worker 関数が終了(子プロセスが終了)して親プロセスに制御が戻ると global_value=1 に戻っています。なぜならば global_value=2 に更新されたのは子プロセスの中だけで、親プロセス側の global_value は影響を受けないからです。

ちなみにこのサンプルプログラムは1つのプログラム内でスレッドとプロセスを両方とも生成していますが、プロセス生成(fork)のタイミングで複数スレッドが動作する状況は本質的に安全ではない点は注意してください。プロセスベースの並列処理を起動する前にスレッドを終了しておくと安全です。※公式ドキュメントのforkに関する説明を参照

データの受け渡しはプロセス間通信

次はデータの受け渡しに利用した Queue クラスについて考えてみます。

スレッドはメモリ空間を共有するので、データのやりとりは同じメモリ空間内で直接行うことができます。
一方でプロセスベースの並列処理で生成される子プロセスは親子関係こそあるものの相互に独立したプロセスなので、データのやりとりにはプロセス間通信を利用してバイト列を送受信します。

スレッドとプロセスは両方とも似たインターフェースを持つ Queue クラスでやりとりしましたが、実はその内部実装は大きく異なります。

動作原理の違いを理解するために、まずはスレッドから defaultdict(lambda: 1) を受け取ってみます。
defaultdict は存在しないキーへアクセスされた場合のデフォルト値を callable の実行結果とすることができるので、lambda と連携してよく使われます。

以下のコードは指定した数値までに含まれる素数をキーに設定した defaultdict を返却します。
defaultdict は存在しないキーにアクセスしたタイミングで callable が実行されてキーと値が生成されまるため、primes[i] の部分で i に対して lambda: 1 の実行結果である 1 がセットされます。

from collections import defaultdict
from threading import Thread
from queue import Queue

def get_primes(num: int, queue: Queue) -> None:
    primes = defaultdict(lambda: 1)
    for i in range(2, num + 1):
        for j in range(2, i):
            if i % j == 0:
                break
        else:
            primes[i]
    queue.put(primes)

queue = Queue()
thread = Thread(target=get_primes, args=(100000, queue))
thread.start()
thread.join()
primes = queue.get()
print(sum(primes.values()))

スレッドベースの並行・並列処理では Queue を介して問題なく defaultdict を受け取ることができました。

では次にプロセスベースの並列処理で同じことをやってみます。

from collections import defaultdict
from multiprocessing import Process
from multiprocessing import Queue

def get_primes(num: int, queue: Queue) -> None:
    primes = defaultdict(lambda: 1)
    for i in range(2, num + 1):
        for j in range(2, i):
            if i % j == 0:
                break
        else:
            primes[i]
    queue.put(primes)

queue = Queue()
process = Process(target=get_primes, args=(100000, queue))
process.start()
process.join()
primes = queue.get()
print(sum(primes.values()))

しかしこれは残念ながら実行すると永遠に終了しません。

multiprocessing.Queue はプロセス間通信を行うクラスなので受け渡すデータをシリアライズしてバイト列に変換し、受け取り側はデシリアライズします。しかし defaultdict に渡した lambda: 1 はシリアライズできません。
子プロセス側は素数の計算処理をして Queue へデータを受け渡しますが、シリアライズに失敗します。一方の親プロセスは primes = queue.get()Queue からデータが送られるのを待つためストールします。

シリアライズには pickle を利用するので pickle が対応しているオブジェクトはシリアライズが可能です。例えば以下のようにdefaultdict に渡す callablelambda: 1 ではなく fork 前に def で事前に定義した関数に変更すれば問題なく処理できるようになります。

def return_one():
    return 1

def get_primes(num: int, queue: Queue) -> None:
    primes = defaultdict(return_one)
    ...

子プロセスにデータを渡す方法

親プロセスから子プロセスに対して大量のデータを渡したいとします。
しかし生成済みの子プロセスにデータを送信する手段は前述の通りプロセス間通信になるためパフォーマンスは必ずしも良くありません。

一方で子プロセスを作成する負荷(つまり fork の負荷)はそれほど高くないため、親プロセスから子プロセスに大量データを渡したい場合は、親プロセス側でデータを生成した後に子プロセスを生成する方法がパフォーマンス的に優れています。
ちなみに fork で子プロセスにリソースがコピーされるものの、CoW(コピーオンライト)という仕組みのおかげで読み取りだけなら実際のメモリ領域を消費しません。

では実際に Processargs を利用して事前に生成しておいた2GBの巨大な文字列を渡してみます。
この方法ならば子プロセス側でも親プロセスで生成した2GBの巨大な文字列をメモリ空間から直接取得できるはずです。

from multiprocessing import Process, Queue

def worker(target: str, queue: Queue):
    queue.put(len(target))

# 2GB相当の文字列を生成
huge_str = "A" * 1 * 1024 * 1024 * 1024 * 2

queue = Queue()

# 2GB相当の文字列を渡して子プロセスを開始
process = Process(target=worker, args=(huge_str, queue))
process.start()
process.join()

# 実行結果の受け取り
print(queue.get())

子プロセスは fork したタイミングで親プロセスが保持するメモリ空間のコピーを持つので1秒もかからず処理が終了しました。

2147483648

real    0m0.797s
user    0m0.359s
sys     0m0.439s

次は比較として子プロセスを生成した後に Queue で2GBのデータを送信したところ 30秒以上かかってしまいました。

from multiprocessing import Process, Queue

def worker(queue: Queue):
    queue.put(len(queue.get()))

queue = Queue()

# 子プロセスを開始
process = Process(target=worker, args=(queue,))
process.start()

# 2GB相当の文字列を生成して子プロセスに送信
huge_str = "A" * 1 * 1024 * 1024 * 1024 * 2
queue.put(huge_str)
process.join()

# 実行結果の受け取り
print(queue.get())
2147483648

real    0m30.700s
user    0m5.584s
sys     0m26.285s

プロセス間通信に時間がかかる上に、親プロセス側はシリアライズして子プロセスにデータを送信し、子プロセスはそれをデシリアライズして受け取るという流れで、同じ変数に対するメモリ領域を何度も確保してコピーした結果パフォーマンスが劣化しました。

親プロセスから子プロセスへデータを渡す場合は親プロセス側でデータを作成した後に子プロセスを生成する方法が CoW を効率的に利用できて良さそうです。
ただし子プロセスから親プロセスへデータを渡す場合は同じ方法を使うことができないので、multiprocessing.Queue などを利用したプロセス間通信に頼らざるをえません。

まとめ

Python コードを並列処理したい場合はプロセスベースでやりましょう、でも注意点もありますよ。という内容でした。
スレッドを使った並列処理を実装しようとしても、実際には並行処理になっていて「なぜか速くならないな・・・」と思っている方に届けば幸いです。

ちなみにスレッドで問題になった GIL はシステムコール実行中は解放されます。そのためストレージIOやDB操作がボトルネックならばスレッドでも並列処理が可能なのでパフォーマンスが改善します。
スレッドならば扱いやすいので両者の特徴をしっかり考えた上で適切な並列処理の方法を選びましょう。

Discussion