[強化学習] Ape-Xの高速化を実現
記事としては Zenn 初投稿。スクラップ記事で色々まとめたことを元に開発した。
TL;DR
- 強化学習のExperience Replay (経験再生) のためのライブラリ cpprb を開発してるよ
- 共有メモリによるデータ共有と効率的なロックの実装で、Ape-X を高速化したよ
- cpprb を利用する際の Ape-X のサンプル実装もあるよ
1. はじめに
2019年の1月頃から、強化学習のExperience Replay (経験再生) のためのバッファライブラリcpprbを開発している。
上の紹介記事にも書いたが、自由度と効率性を重視して開発を行っている。
保存するトランザクションの値を柔軟に設定することが可能で、また Prioritized Experience Replay に利用するSegment TreeをC++で実装することで、Pythonで実装するよりも高速化を実現している。
使ってみていいなと思ったら、以下のレポジトリにスターをつけてくれると嬉しい。
2. 今回の課題
TensorFlow や PyTorch などの深層学習のライブラリはGPUに最適化されている一方、強化学習を行う環境(シミュレーターなど)はGPUで高速に実行できるものばかりではなく、強化学習を行う際の律速となることがある。
そんな条件下では、強化学習にかかる時間を短縮させるひとつの手段として、Ape-X のように環境の探索をネットワークの学習と分離し、かつ複数同時に実行させる方法が取られている。
並列処理を実現するにあたって、プロセス間でデータを高速かつ整合性を持ってやり取りすることが必要になってくる。 (Python (CPython) では GIL があるため、並列に複数の処理を実行するには、マルチ・スレッドではなく、マルチ・プロセスを利用する。)
プロセス間でデータを受け渡す単純な方法は、プロキシ (multiprocessing.managers.SyncManager
) や、キュー (multiprocessing.Queue
) などを利用する方法だが、データ量が大きくなると遅くなりやすい。
またデータの整合性を担保するために、グローバル・バッファを操作する際にロックをかける必要があるが、単純にバッファ全体をロックすると複数プロセスがロック待ちをする時間が非常に長くなりやすい。
# explorerの単純な実装例
if local_buffer.get_stored_size() > local_size:
local_sample = local_buffer.get_all_transitions()
local_buffer.clear()
with lock:
global_buffer.add(**local_sample)
# learnerの単純な実装例
with lock:
sample = global_buffer.sample(batch_size)
TD = model.compute_TD_error(sample)
global_buffer.update_priorities(sample["indexes"],np.abs(TD))
まとめると、Ape-Xの高速化には、以下の点の改善が必要だと考える。
- 高速なプロセス間通信
- データの整合性を保ちつつロックを効率化
3. cpprbの対応
これまでの ReplayBuffer
PrioritizedReplayBuffer
クラスに加え、マルチプロセスサポート機能を実装した MPReplayBuffer
MPPrioritizedReplayBuffer
クラスを新たに実装した。
詳細は公式ドキュメントに記載しているが、この記事では一部を抜粋して紹介する。
3.1 共有メモリ
プロセス間のデータを高速に同期するために、内部データを共有メモリに載せる方式を採用した。
メモリの物理的に同じ領域を複数のプロセスから同時に利用する方式のため、無駄なデータコピーを削減でき高速化が見込まれる。
これまで numpy.ndarray で内部データを保持していたので、互換性を持つように共有メモリ上に ndarray のビューを配置する次のようなクラスを実装した。
(Cython を利用しているが、この部分実装には特に必要なかったようにも思う。)
from multiprocessing.sharedctypes import RawArray
import numpy as no
cdef class SharedBuffer:
cdef data
cdef view
def __init__(self,shape,dtype,data=None):
ctype = np.ctypeslib.as_ctypes_type(dtype)
len = int(np.array(shape,copy=False,dtype="int").prod())
self.data = data or RawArray(ctype,len)
self.view = np.ctypeslib.as_array(self.data)
self.view.shape = shape
def __getitem__(self,key):
return self.view[key]
def __setitem__(self,key,value):
self.view[key] = value
def __reduce__(self):
return (SharedBuffer,(self.view.shape,self.view.dtype,self.data))
3.2 効率的なロック
次の図は Prioritized Replay Buffer のデータ構造である。
バッファ全体ではなく、クリティカルセクションだけをロックする方針でロックを実装していく。
出典: https://ymd_h.gitlab.io/cpprb/features/ape-x/
3.2.1 同時書き込み
cpprb では Replay Buffer はリングバッファで実装されており、一連の書き込みは異なるアドレスに順番に書き込まれる。1つのプロセスが書き込んでいる間にバッファ全体をロックする必要はなく、書き込み先インデックスを適切にロックして参照・増加させれば、複数のプロセスが同時に異なるアドレスに書き込んでも問題ない。
ReplayBuffer
クラスの中にばらばらで実装されていたインデックス操作を RingBufferIndex
と切り出し、それを継承する形で操作時にロックする ProcessSafeRingBufferIndex
を実装した。 (副作用として、普段の ReplayBuffer
もインデックスを共有メモリ上に配置することになったが、何のロックもアトミックな操作もしていないため速度面での影響は実験した限り無かった。)
cdef class RingBufferIndex:
"""Ring Buffer Index class
"""
cdef index
cdef buffer_size
cdef is_full
def __init__(self,buffer_size):
self.index = RawValue(ctypes.c_size_t,0)
self.buffer_size = RawValue(ctypes.c_size_t,buffer_size)
self.is_full = RawValue(ctypes.c_int,0)
cdef size_t get_next_index(self):
return self.index.value
cdef size_t fetch_add(self,size_t N):
"""
Add then return original value
Parameters
----------
N : size_t
value to add
Returns
-------
size_t
index before add
"""
cdef size_t ret = self.index.value
self.index.value += N
if self.index.value >= self.buffer_size.value:
self.is_full.value = 1
while self.index.value >= self.buffer_size.value:
self.index.value -= self.buffer_size.value
return ret
cdef void clear(self):
self.index.value = 0
self.is_full.value = 0
cdef size_t get_stored_size(self):
if self.is_full.value:
return self.buffer_size.value
else:
return self.index.value
cdef class ProcessSafeRingBufferIndex(RingBufferIndex):
"""Process Safe Ring Buffer Index class
"""
cdef lock
def __init__(self,buffer_size):
super().__init__(buffer_size)
self.lock = Lock()
cdef size_t get_next_index(self):
with self.lock:
return RingBufferIndex.get_next_index(self)
cdef size_t fetch_add(self,size_t N):
with self.lock:
return RingBufferIndex.fetch_add(self,N)
cdef void clear(self):
with self.lock:
RingBufferIndex.clear(self)
cdef size_t get_stored_size(self):
with self.lock:
return RingBufferIndex.get_stored_size(self)
3.2.2 読み書き排他制御
複数explorerプロセスから同時に書き込む (add
) ことは可能だが、learnerプロセスが読み込んでいる(sample
) 途中でデータを書き換えられてしまうと整合性がなくなってしまう。そのため learner プロセスが読み込む (sample
および update_priorities
) 際には、 explorer を排除しなければならない。
複数のexplorerが同時に書き込む必要があるので、単純なロックでは要件を満たせない。クリティカルセクションの中にいるexplorerの数を数えてロックする必要がある。
実装の一部を以下に示す。これらの関数をクリティカルセクションにアクセスする際に利用して排他制御を行っている。
def __init__(self,size,env_dict=None,*,default_dtype=None,logger=None,**kwargs):
# 中略
self.learner_ready = Event()
self.learner_ready.clear()
self.explorer_ready = Event()
self.explorer_ready.set()
self.explorer_count = Value(ctypes.c_size_t,0)
cdef void _lock_explorer(self) except *:
self.explorer_ready.wait() # Wait permission
self.learner_ready.clear() # Block learner
with self.explorer_count.get_lock():
self.explorer_count.value += 1
cdef void _unlock_explorer(self) except *:
with self.explorer_count.get_lock():
self.explorer_count.value -= 1
if self.explorer_count.value == 0:
self.learner_ready.set()
cdef void _lock_learner(self) except *:
self.explorer_ready.clear() # New explorer cannot enter into critical section
self.learner_ready.wait() # Wait until all explorer exit from critical section
cdef void _unlock_learner(self) except *:
self.explorer_ready.set() # Allow workers to enter into critical section
4. cpprbを利用したApe-Xのサンプル実装
Ape-Xの骨格だけで、深層学習のネットワークは含まれていない。
実際に利用する際には、 MyModel
としている部分を書き換えてほしい
出典: https://ymd_h.gitlab.io/cpprb/features/ape-x/
from multiprocessing import Process, Event, SimpleQueue
import time
import gym
import numpy as np
from tqdm import tqdm
from cpprb import ReplayBuffer, MPPrioritizedReplayBuffer
class MyModel:
def __init__(self):
self._weights = 0
def get_action(self,obs):
# Implement action selection
return 0
def abs_TD_error(self,sample):
# Implement absolute TD error
return np.zeros(sample["obs"].shape[0])
@property
def weights(self):
return self._weights
@weights.setter
def weights(self,w):
self._weights = w
def train(self,sample):
# Implement model update
pass
def explorer(global_rb,env_dict,is_training_done,queue):
local_buffer_size = int(1e+2)
local_rb = ReplayBuffer(local_buffer_size,env_dict)
model = MyModel()
env = gym.make("CartPole-v1")
obs = env.reset()
while not is_training_done.is_set():
if not queue.empty():
w = queue.get()
model.weights = w
action = model.get_action(obs)
next_obs, reward, done, _ = env.step(action)
local_rb.add(obs=obs,act=action,rew=reward,next_obs=next_obs,done=done)
if done:
local_rb.on_episode_end()
obs = env.reset()
else:
obs = next_obs
if local_rb.get_stored_size() == local_buffer_size:
local_sample = local_rb.get_all_transitions()
local_rb.clear()
absTD = model.abs_TD_error(local_sample)
global_rb.add(**local_sample,priorities=absTD)
def learner(global_rb,queues):
batch_size = 64
n_warmup = 100
n_training_step = int(1e+4)
explorer_update_freq = 100
model = MyModel()
while global_rb.get_stored_size() < n_warmup:
time.sleep(1)
for step in tqdm(range(n_training_step)):
sample = global_rb.sample(batch_size)
model.train(sample)
absTD = model.abs_TD_error(sample)
global_rb.update_priorities(sample["indexes"],absTD)
if step % explorer_update_freq == 0:
w = model.weights
for q in queues:
q.put(w)
if __name__ == "__main__":
buffer_size = int(1e+6)
env_dict = {"obs": {"shape": 4},
"act": {},
"rew": {},
"next_obs": {"shape": 4},
"done": {}}
n_explorer = 4
global_rb = MPPrioritizedReplayBuffer(buffer_size,env_dict)
is_training_done = Event()
is_training_done.clear()
qs = [SimpleQueue() for _ in range(n_explorer)]
ps = [Process(target=explorer,
args=[global_rb,env_dict,is_training_done,q])
for q in qs]
for p in ps:
p.start()
learner(global_rb,qs)
is_training_done.set()
for p in ps:
p.join()
print(global_rb.get_stored_size())
5. 最後に
cpprb の開発の発端となった友人の開発する tf2rl のApe-X実装にもPRを投げたので、
マージされれば更に高速化するはず。
tf2rl は、以下の紹介記事にもあるとおり、多くの強化学習の手法が実装されたライブラリなので、
こちらもぜひ使ってみてください。
追記
Qiita にエンドユーザー向けの記事を公開しました。
Discussion