Closed11
【強化学習】Replay Buffer をRayで利用するための調査・情報収集
cpprb を Ray で使いたいとの要望があり、調査・情報収集のためのスクラップ
GitLab の repo にも issue を立てた
ありがたいことに報告者がPR (のドラフト) を作ってくれた。
[ポイント]
- データは
multiprocessing.shared_memory.SharedMemory
に配置 -
Lock
/Event
は、multiprocessing.managers.SyncManager
経由で同期-
multiprocessing.current_process().auth_key
を指定することで、fork
/spawn
以外で作成したプロセスとも通信できるようになる。
-
[気になるところ]
- Python 3.7 用のフォールバックがほしい
-
SyncManager
経由の同期はいつもではなく必要なときだけにしたい -
SyncManager
は立ち上げっぱなしで、終了していない。
ray.util.inspect_serializability
関数で Serialize化できるかをチェックできる。
内部に持っている mmap.mmap
がシリアライズできない。
from cpprb import MPPrioritizedReplayBuffer
from ray.util import inspect_serializability
rb = MPPrioritizedReplayBuffer(4, {"done": {}})
ray.util.inspect_serializability(rb, name="MPPrioritizedReplayBuffer")
================================================================================
Checking Serializability of <cpprb.PyReplayBuffer.MPPrioritizedReplayBuffer object at 0x7f83a83d6220>
================================================================================
!!! FAIL serialization: cannot pickle 'mmap.mmap' object
WARNING: Did not find non-serializable object in <cpprb.PyReplayBuffer.MPPrioritizedReplayBuffer object at 0x7f83a83d6220>. This may be an oversight.
================================================================================
Lock オブジェクトが渡せない例として挙げられている
おそらく共有メモリーも渡せないと思われる。
Serializeして Worker 間を渡すデータ以外に、共有メモリー上に配置する Plasma データストアがある。
ただし、ここは不変なデータ置き場のため、Replay Buffer を配置することはできない。
Python 3.8 で追加された SharedMemory
クラスは一意な名前付きの共有メモリを作成でき、他のプロセスからも名前で引き当てる事ができる。
要検証だが、この方法であれば、Rayのシリアライズを回避して共有メモリを保持できるかもしれない。
デメリットは、未だサポート内の Python 3.7 で利用できないこと。
Workerが単一マシン上にあるならば、これがどちらのケースも意図通りに動く。
import multiprocessing
from multiprocessing.shared_memory import SharedMemory
import numpy as np
import ray
def ray_test():
ray.init()
shm = SharedMemory(create=True, size=32 * 3)
a = np.ndarray(shape=(3,), dtype=np.int32, buffer=shm.buf)
print(a)
@ray.remote
def add(name, shape, dtype):
m = SharedMemory(name=name)
b = np.ndarray(shape=shape, dtype=dtype, buffer=m.buf)
print(b)
b += 2
print(b)
@ray.remote
def add_shm(shm, shape, dtype):
b = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
print(b)
b += 2
print(b)
ray.get(add.remote(shm.name, a.shape, a.dtype))
print(a)
ray.get(add_shm.remote(shm, a.shape, a.dtype))
print(a)
shm.close()
shm.unlink()
if __name__ == "__main__":
ray_test()
このスクラップは2022/02/27にクローズされました