🎉

[強化学習] Ape-Xの高速化を実現

2021/01/16に公開

記事としては Zenn 初投稿。スクラップ記事で色々まとめたことを元に開発した。

https://zenn.dev/ymd_h/scraps/e68d0ecd89c12a

TL;DR

  • 強化学習のExperience Replay (経験再生) のためのライブラリ cpprb を開発してるよ
  • 共有メモリによるデータ共有と効率的なロックの実装で、Ape-X を高速化したよ
  • cpprb を利用する際の Ape-X のサンプル実装もあるよ

1. はじめに

2019年の1月頃から、強化学習のExperience Replay (経験再生) のためのバッファライブラリcpprbを開発している。

https://qiita.com/ymd_h/items/21071d7778cfb3cd596a

上の紹介記事にも書いたが、自由度と効率性を重視して開発を行っている。

保存するトランザクションの値を柔軟に設定することが可能で、また Prioritized Experience Replay に利用するSegment TreeをC++で実装することで、Pythonで実装するよりも高速化を実現している。

使ってみていいなと思ったら、以下のレポジトリにスターをつけてくれると嬉しい。

https://gitlab.com/ymd_h/cpprb
https://github.com/ymd-h/cpprb

2. 今回の課題

TensorFlowPyTorch などの深層学習のライブラリはGPUに最適化されている一方、強化学習を行う環境(シミュレーターなど)はGPUで高速に実行できるものばかりではなく、強化学習を行う際の律速となることがある。

そんな条件下では、強化学習にかかる時間を短縮させるひとつの手段として、Ape-X のように環境の探索をネットワークの学習と分離し、かつ複数同時に実行させる方法が取られている。

並列処理を実現するにあたって、プロセス間でデータを高速かつ整合性を持ってやり取りすることが必要になってくる。 (Python (CPython) では GIL があるため、並列に複数の処理を実行するには、マルチ・スレッドではなく、マルチ・プロセスを利用する。)

プロセス間でデータを受け渡す単純な方法は、プロキシ (multiprocessing.managers.SyncManager) や、キュー (multiprocessing.Queue) などを利用する方法だが、データ量が大きくなると遅くなりやすい。

またデータの整合性を担保するために、グローバル・バッファを操作する際にロックをかける必要があるが、単純にバッファ全体をロックすると複数プロセスがロック待ちをする時間が非常に長くなりやすい。

# explorerの単純な実装例
if local_buffer.get_stored_size() > local_size:
    local_sample = local_buffer.get_all_transitions()
    local_buffer.clear()

    with lock:
        global_buffer.add(**local_sample)
# learnerの単純な実装例
with lock:
    sample = global_buffer.sample(batch_size)
    TD = model.compute_TD_error(sample)
    global_buffer.update_priorities(sample["indexes"],np.abs(TD))

まとめると、Ape-Xの高速化には、以下の点の改善が必要だと考える。

  1. 高速なプロセス間通信
  2. データの整合性を保ちつつロックを効率化

3. cpprbの対応

これまでの ReplayBuffer PrioritizedReplayBuffer クラスに加え、マルチプロセスサポート機能を実装した MPReplayBuffer MPPrioritizedReplayBuffer クラスを新たに実装した。

詳細は公式ドキュメントに記載しているが、この記事では一部を抜粋して紹介する。

3.1 共有メモリ

プロセス間のデータを高速に同期するために、内部データを共有メモリに載せる方式を採用した。
メモリの物理的に同じ領域を複数のプロセスから同時に利用する方式のため、無駄なデータコピーを削減でき高速化が見込まれる。

これまで numpy.ndarray で内部データを保持していたので、互換性を持つように共有メモリ上に ndarray のビューを配置する次のようなクラスを実装した。
(Cython を利用しているが、この部分実装には特に必要なかったようにも思う。)

from multiprocessing.sharedctypes import RawArray
import numpy as no

cdef class SharedBuffer:
    cdef data
    cdef view
    def __init__(self,shape,dtype,data=None):
        ctype = np.ctypeslib.as_ctypes_type(dtype)
        len = int(np.array(shape,copy=False,dtype="int").prod())
        self.data = data or RawArray(ctype,len)
        self.view = np.ctypeslib.as_array(self.data)
        self.view.shape = shape

    def __getitem__(self,key):
        return self.view[key]

    def __setitem__(self,key,value):
        self.view[key] = value

    def __reduce__(self):
        return (SharedBuffer,(self.view.shape,self.view.dtype,self.data))

3.2 効率的なロック

次の図は Prioritized Replay Buffer のデータ構造である。
バッファ全体ではなく、クリティカルセクションだけをロックする方針でロックを実装していく。


出典: https://ymd_h.gitlab.io/cpprb/features/ape-x/

3.2.1 同時書き込み

cpprb では Replay Buffer はリングバッファで実装されており、一連の書き込みは異なるアドレスに順番に書き込まれる。1つのプロセスが書き込んでいる間にバッファ全体をロックする必要はなく、書き込み先インデックスを適切にロックして参照・増加させれば、複数のプロセスが同時に異なるアドレスに書き込んでも問題ない。

ReplayBuffer クラスの中にばらばらで実装されていたインデックス操作を RingBufferIndex と切り出し、それを継承する形で操作時にロックする ProcessSafeRingBufferIndex を実装した。 (副作用として、普段の ReplayBuffer もインデックスを共有メモリ上に配置することになったが、何のロックもアトミックな操作もしていないため速度面での影響は実験した限り無かった。)

cdef class RingBufferIndex:
    """Ring Buffer Index class
    """
    cdef index
    cdef buffer_size
    cdef is_full

    def __init__(self,buffer_size):
        self.index = RawValue(ctypes.c_size_t,0)
        self.buffer_size = RawValue(ctypes.c_size_t,buffer_size)
        self.is_full = RawValue(ctypes.c_int,0)

    cdef size_t get_next_index(self):
        return self.index.value

    cdef size_t fetch_add(self,size_t N):
        """
        Add then return original value

        Parameters
        ----------
        N : size_t
            value to add

        Returns
        -------
        size_t
            index before add
        """
        cdef size_t ret = self.index.value
        self.index.value += N

        if self.index.value >= self.buffer_size.value:
            self.is_full.value = 1

        while self.index.value >= self.buffer_size.value:
            self.index.value -= self.buffer_size.value

        return ret

    cdef void clear(self):
        self.index.value = 0
        self.is_full.value = 0

    cdef size_t get_stored_size(self):
        if self.is_full.value:
            return self.buffer_size.value
        else:
            return self.index.value

cdef class ProcessSafeRingBufferIndex(RingBufferIndex):
    """Process Safe Ring Buffer Index class
    """
    cdef lock

    def __init__(self,buffer_size):
        super().__init__(buffer_size)
        self.lock = Lock()

    cdef size_t get_next_index(self):
        with self.lock:
            return RingBufferIndex.get_next_index(self)

    cdef size_t fetch_add(self,size_t N):
        with self.lock:
            return RingBufferIndex.fetch_add(self,N)

    cdef void clear(self):
        with self.lock:
            RingBufferIndex.clear(self)

    cdef size_t get_stored_size(self):
        with self.lock:
            return RingBufferIndex.get_stored_size(self)

3.2.2 読み書き排他制御

複数explorerプロセスから同時に書き込む (add) ことは可能だが、learnerプロセスが読み込んでいる(sample) 途中でデータを書き換えられてしまうと整合性がなくなってしまう。そのため learner プロセスが読み込む (sample および update_priorities) 際には、 explorer を排除しなければならない。

複数のexplorerが同時に書き込む必要があるので、単純なロックでは要件を満たせない。クリティカルセクションの中にいるexplorerの数を数えてロックする必要がある。

実装の一部を以下に示す。これらの関数をクリティカルセクションにアクセスする際に利用して排他制御を行っている。

    def __init__(self,size,env_dict=None,*,default_dtype=None,logger=None,**kwargs):
        
	# 中略
	
        self.learner_ready = Event()
        self.learner_ready.clear()
        self.explorer_ready = Event()
        self.explorer_ready.set()

        self.explorer_count = Value(ctypes.c_size_t,0)

    cdef void _lock_explorer(self) except *:
        self.explorer_ready.wait() # Wait permission
        self.learner_ready.clear()  # Block learner
        with self.explorer_count.get_lock():
            self.explorer_count.value += 1

    cdef void _unlock_explorer(self) except *:
        with self.explorer_count.get_lock():
            self.explorer_count.value -= 1
        if self.explorer_count.value == 0:
            self.learner_ready.set()

    cdef void _lock_learner(self) except *:
        self.explorer_ready.clear() # New explorer cannot enter into critical section
        self.learner_ready.wait() # Wait until all explorer exit from critical section

    cdef void _unlock_learner(self) except *:
        self.explorer_ready.set() # Allow workers to enter into critical section

4. cpprbを利用したApe-Xのサンプル実装

Ape-Xの骨格だけで、深層学習のネットワークは含まれていない。
実際に利用する際には、 MyModel としている部分を書き換えてほしい

出典: https://ymd_h.gitlab.io/cpprb/features/ape-x/

from multiprocessing import Process, Event, SimpleQueue
import time

import gym
import numpy as np
from tqdm import tqdm

from cpprb import ReplayBuffer, MPPrioritizedReplayBuffer


class MyModel:
    def __init__(self):
        self._weights = 0

    def get_action(self,obs):
        # Implement action selection
        return 0

    def abs_TD_error(self,sample):
        # Implement absolute TD error
        return np.zeros(sample["obs"].shape[0])

    @property
    def weights(self):
        return self._weights

    @weights.setter
    def weights(self,w):
        self._weights = w

    def train(self,sample):
        # Implement model update
        pass


def explorer(global_rb,env_dict,is_training_done,queue):
    local_buffer_size = int(1e+2)
    local_rb = ReplayBuffer(local_buffer_size,env_dict)

    model = MyModel()
    env = gym.make("CartPole-v1")

    obs = env.reset()
    while not is_training_done.is_set():
        if not queue.empty():
            w = queue.get()
            model.weights = w

        action = model.get_action(obs)
        next_obs, reward, done, _ = env.step(action)
        local_rb.add(obs=obs,act=action,rew=reward,next_obs=next_obs,done=done)

        if done:
            local_rb.on_episode_end()
            obs = env.reset()
        else:
            obs = next_obs

        if local_rb.get_stored_size() == local_buffer_size:
            local_sample = local_rb.get_all_transitions()
            local_rb.clear()

            absTD = model.abs_TD_error(local_sample)
            global_rb.add(**local_sample,priorities=absTD)


def learner(global_rb,queues):
    batch_size = 64
    n_warmup = 100
    n_training_step = int(1e+4)
    explorer_update_freq = 100

    model = MyModel()

    while global_rb.get_stored_size() < n_warmup:
        time.sleep(1)

    for step in tqdm(range(n_training_step)):
        sample = global_rb.sample(batch_size)

        model.train(sample)
        absTD = model.abs_TD_error(sample)
        global_rb.update_priorities(sample["indexes"],absTD)

        if step % explorer_update_freq == 0:
            w = model.weights
            for q in queues:
                q.put(w)


if __name__ == "__main__":
    buffer_size = int(1e+6)
    env_dict = {"obs": {"shape": 4},
                "act": {},
                "rew": {},
                "next_obs": {"shape": 4},
                "done": {}}
    n_explorer = 4

    global_rb = MPPrioritizedReplayBuffer(buffer_size,env_dict)

    is_training_done = Event()
    is_training_done.clear()

    qs = [SimpleQueue() for _ in range(n_explorer)]
    ps = [Process(target=explorer,
                  args=[global_rb,env_dict,is_training_done,q])
          for q in qs]

    for p in ps:
        p.start()

    learner(global_rb,qs)
    is_training_done.set()

    for p in ps:
        p.join()

    print(global_rb.get_stored_size())

5. 最後に

cpprb の開発の発端となった友人の開発する tf2rl のApe-X実装にもPRを投げたので、
マージされれば更に高速化するはず。

https://github.com/keiohta/tf2rl/pull/122

tf2rl は、以下の紹介記事にもあるとおり、多くの強化学習の手法が実装されたライブラリなので、
こちらもぜひ使ってみてください。

https://keiohta.github.io/posts/python/tf2rl/
https://qiita.com/ymd_h/items/0f1be143546ccd4f4c33

追記

Qiita にエンドユーザー向けの記事を公開しました。

https://qiita.com/ymd_h/items/ac9e3f1315d56a1b2718

Discussion