🤖

【強化学習】Large Batch Experience Replay (LaBER)

2021/12/02に公開約11,500字

この記事は、強化学習 Advent Calendar 2021の12/3の記事です。

0. はじめに

強化学習のExperience Replayに興味がありZennで論文調査などの記事を書いたり、Replay Bufferライブラリ cpprb を開発したりしています。

https://zenn.dev/ymd_h/articles/c3ba23033a6442
https://zenn.dev/ymd_h/articles/35d88f7d739651

今回、面白そうな論文があったので、その紹介と実装してみましたの記事です。
(cpprbのサイトのページの日本語焼き直しです。)

1. [論文紹介] Large Batch Experience Replay (LaBER)

T. Lahire et al., "Large Batch Experience Replay", CoRR (2021) (arXiv, code)

1.1 前提

強化学習(のoff-policyな手法)では、遷移 (一般には (s_t, a_t, r_t, s_{t+1}, d_t) の組)をReplay Bufferに保存しておき、後から『ランダム』に取り出してニューラルネットワーク等のポリシーを学習させることで、サンプル効率を高めるExperience Replay (経験再生) が広く行われています。

2016年にT. Schaul等[1]は、この『ランダム』な取り出しを、完全な一様分布ではなく、Q学習で最小化するターゲットであるTD誤差に応じて選ぶ Prioritized Experience Replay (PER) を提案し、学習を速くできることを実験的に示しました。

PERは経験的にはうまく働くものも理論的にはなぜ良いのかは完全には明らかにはされてきませんでした。

1.2 提案手法

そこで著者らは、2017年のL. Wang等のSGDの収束速度の研究[2]に基づき、Replay Buffer上のデータを学習させる際の収束速度を最適化するサンプリング確率が p^{\ast} \propto \|\nabla _{\theta} L (Q_{\theta}(s_i, a_i), y_i) \| となることを導出しました。この分布は、最終出力層に活性化関数がなく、損失関数がL2ノルムの際には、p^{\ast} \propto \|\partial L(q_i, y_i)/ \partial q_i \| とTD誤差に一致します。また、完全には一致せずとも \nabla _{\theta} Q_{\theta}(s_i, a_i) が遷移間でおよそ一定である条件下では、TD誤差が十分に良い近似となるだろうと言及しています。

また、著者らはPERの実装上の課題として、一旦サンプルした遷移の重要度(=TD誤差)のみを更新するため、常に重要度が古いことも挙げています。

上記の2つの点を念頭に著者らは、重要度に厳密な値を利用する「Gradient Experience Replay (GER)」と、重要度は近似を用いるが最新の値を常に利用する「Large Batch Experience Replay (LaBER)」を提案し比較しました。

手法の関係性は以下の表のようになります。厳密な重要度計算は計算コストが高いので、最新の値を利用するのは現実的ではないと著者らは位置づけています。

重要度 厳密 近似
古い GER PER
最新 (非現実的) LaBER

では、近似であればサンプルの度に、Replay Buffer内の10^6程度あるような遷移に対して重要度計算をすることができるのかと言えば、やはり計算コストが大きすぎます。
そこで、LaBERでは、一旦少し大きめ(バッチサイズのm倍)の遷移をReplay Bufferから抽出し、その遷移に対して重要度を計算し、重要度に応じてその中から抽出することで計算コストを抑えています。何度も繰り返せば、一旦取り出したm倍バッチサイズの遷移が元のReplay Bufferの分布の縮小になっているはずだという期待に基づいていると解釈できるでしょう。

LaBERで利用する重要度の「近似」は、surrogate priority (代替重要度?) と呼んでいる \hat{p}_i \propto \hat{G}_i = K\| \Sigma (z_i) \partial L(q_i, y_i)/\partial q_i \| を使います。詳細な説明は原論文を見ていただければと思いますが、最終出力層に活性化関数がなく損失関数がL2ノルムの際には、TD誤差になります。

最適な分布に基づいてReplay Bufferからバッチ遷移を抽出する場合の期待値を算出して、surrogate priorityを代入すると以下のようになります。(Bがバッチ、NがReplay Buffer全体の遷移数です。)

\mathbb{E}_{i\sim p} \left[ \nabla _{\theta} L(Q_{\theta}, y_i)\frac{1}{p_{i}}\right] \approx \frac{1}{B} \sum ^{B}_{i=1}\nabla _{\theta}L(Q_{\theta}, y_i)\frac{1}{p_i} = \frac{1}{B}\frac{\sum_{j=1}^N \hat{G}_j}{N}\sum _{i=1}^B\frac{1}{\hat{G}_i}\nabla _{\theta} L(Q_{\theta}, y_i)

本来Replay Buffer全体で計算すべきの \sum_{j=1}^N \hat{G}_j/N をどのように計算するかに課題があり、著者らは3つのバージョンを提案しています。(理論的にはLaBER-mean以外はバイアスが生じます。)

LaBER-mean
m倍のバッチでの値で代替する

\frac{1}{B}\frac{\sum_{j=1}^{mB} \hat{G}_j}{mB}\sum _{i=1}^B\frac{1}{\hat{G}_i}\nabla _{\theta} L(Q_{\theta}, y_i)

LaBER-lazy
単純に無視する

\frac{1}{B}\sum _{i=1}^B\frac{1}{\hat{G}_i}\nabla _{\theta} L(Q_{\theta}, y_i)

LaBER-max
(PERで安定化のために正規化しているのを倣い)バッチの重みが最小になるように最大の値で正規化する

\frac{1}{B}\left(\min _{j\in[1,B]}\hat{G}_j\right)\sum _{i=1}^B\frac{1}{\hat{G}_i}\nabla _{\theta} L(Q_{\theta}, y_i)

1.3 実験

著者らは、Atariのゲームを用いて包括的に実験を行い提案手法の検証をしました。

  • surrogate priorityは近似を用いない本当の確率分布と同様の精度を達成できるのか?
  • LaBERの精度はDQNを上回るのか?
  • m倍のmはどのくらいにしたら良いのか?
  • LaBERの正規化はどれを選べばよいのか?
  • LaBERは設計思想どおり精度を挙げつつ、分散を抑えられるのか?
  • LaBERで採用している重要度の最新化が鍵となっているのか?

以下に実験結果の一例を引用しますが、これらの結果を受け「surrogate priorityは十分によく近似している」「LaBERはDQNより精度が良く計算速度も速い」「mは大きい方が精度が良い」「理論的に適切なLaBER-meanが精度も良い」「LaBERをPER/GERと組み合わせても、計算負荷の割には精度が向上しないので、LaBERの最新重要度の利用が鍵となっている」などと結論づけています。


Atariのゲームにおける手法比較結果。最右列のPER-LaBER/GER-LaBERはPER/GERでm倍バッチを抽出する組み合わせ手法 (原論文より引用)

2. cpprbでの実装と利用方法

cpprbでは、LaBERはReplay Bufferそのものではなく、ReplayBufferクラスと一緒に使う補助的なクラスとして実装しました。
3つの手法がLaBERmeanLaBERlazyLaBERmaxのクラスになっています。

ReplayBufferから取り出した m倍サイズのミニバッチをsurrogate priorityとともに受け取りサブ・サンプリングを実施します。

import cpprb

# LaBER-mean 用のクラス作成
laber = cpprb.LaBERmean(batch_size = 32, m = 4)

# Replay Buffer の作成などいつもの強化学習
rb = cpprb.ReplayBuffer(1e6, { ... })
...

# Replay Buffer からは、m倍のバッチサイズをまず取り出す。
sample = rb.sample(32 * 4)

# 取り出した遷移に対するsurrogate priorityを計算する
TD = ...

# surrogate priority に基づいてサブサンプリングし、インデックスと重みを得る。
index_weight = laber(priorities=TD)
index = index_weight["indexes"]
weight = index_weight["weights"]

# observation / action などの遷移の値自体を一緒に渡せば、
# 選ばれたインデックスに該当する値を一緒に返してくれるので、お好みで使い分けることができる
transitions = laber(priorities=TD, obs=sample["obs"], act=sample["act"])
index = transitions["indexes"]
weight = transitions["weights"]
obs = transitions["obs"]
act = transitions["act"]

3. サンプルコード

https://github.com/ymd-h/cpprb/blob/master/example/dqn-laber.py
# Example for Large Batch Experience Replay (LaBER)
# Ref: https://arxiv.org/abs/2110.01528

import os
import datetime

import numpy as np

import gym

import tensorflow as tf
from tensorflow.keras.models import Sequential, clone_model
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.summary import create_file_writer


from cpprb import ReplayBuffer, LaBERmean


gamma = 0.99
batch_size = 64

N_iteration = int(1e+6)
target_update_freq = 10000
eval_freq = 1000

egreedy = 1.0
decay_egreedy = lambda e: max(e*0.99, 0.1)


# Use 4 times larger batch for initial uniform sampling
# Use LaBER-mean, which is the best variant
m = 4
LaBER = LaBERmean(batch_size, m)


# Log
dir_name = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join("logs", dir_name)
writer = create_file_writer(logdir + "/metrics")
writer.set_as_default()


# Env
env = gym.make('CartPole-v1')
eval_env = gym.make('CartPole-v1')

# For CartPole: input 4, output 2
model = Sequential([Dense(64,activation='relu',
                          input_shape=(env.observation_space.shape)),
                    Dense(env.action_space.n)])
target_model = clone_model(model)


# Loss Function

@tf.function
def Huber_loss(absTD):
    return tf.where(absTD > 1.0, absTD, tf.math.square(absTD))

@tf.function
def MSE(absTD):
    return tf.math.square(absTD)

loss_func = Huber_loss


optimizer = Adam()


buffer_size = 1e+6
env_dict = {"obs":{"shape": env.observation_space.shape},
            "act":{"shape": 1,"dtype": np.ubyte},
            "rew": {},
            "next_obs": {"shape": env.observation_space.shape},
            "done": {}}

# Nstep
nstep = 3
# nstep = False

if nstep:
    Nstep = {"size": nstep, "rew": "rew", "next": "next_obs"}
    discount = tf.constant(gamma ** nstep)
else:
    Nstep = None
    discount = tf.constant(gamma)


rb = ReplayBuffer(buffer_size,env_dict,Nstep=Nstep)


@tf.function
def Q_func(model,obs,act,act_shape):
    return tf.reduce_sum(model(obs) * tf.one_hot(act,depth=act_shape), axis=1)

@tf.function
def DQN_target_func(model,target,next_obs,rew,done,gamma,act_shape):
    return gamma*tf.reduce_max(target(next_obs),axis=1)*(1.0-done) + rew

@tf.function
def Double_DQN_target_func(model,target,next_obs,rew,done,gamma,act_shape):
    """
    Double DQN: https://arxiv.org/abs/1509.06461
    """
    act = tf.math.argmax(model(next_obs),axis=1)
    return gamma*tf.reduce_sum(target(next_obs)*tf.one_hot(act,depth=act_shape), axis=1)*(1.0-done) + rew


target_func = Double_DQN_target_func



def evaluate(model,env):
    obs = env.reset()
    total_rew = 0

    while True:
        Q = tf.squeeze(model(obs.reshape(1,-1)))
        act = np.argmax(Q)
        obs, rew, done, _ = env.step(act)
        total_rew += rew

        if done:
            return total_rew

# Start Experiment

observation = env.reset()

# Warming up
for n_step in range(100):
    action = env.action_space.sample() # Random Action
    next_observation, reward, done, info = env.step(action)
    rb.add(obs=observation,
           act=action,
           rew=reward,
           next_obs=next_observation,
           done=done)
    observation = next_observation
    if done:
        env.reset()
        rb.on_episode_end()


n_episode = 0
observation = env.reset()
for n_step in range(N_iteration):

    if np.random.rand() < egreedy:
        action = env.action_space.sample()
    else:
        Q = tf.squeeze(model(observation.reshape(1,-1)))
        action = np.argmax(Q)

    egreedy = decay_egreedy(egreedy)

    next_observation, reward, done, info = env.step(action)
    rb.add(obs=observation,
           act=action,
           rew=reward,
           next_obs=next_observation,
           done=done)
    observation = next_observation

    # Uniform sampling
    sample = rb.sample(batch_size * m)

    with tf.GradientTape() as tape:
        tape.watch(model.trainable_weights)
        Q =  Q_func(model,
                    tf.constant(sample["obs"]),
                    tf.constant(sample["act"].ravel()),
                    tf.constant(env.action_space.n))
        target_Q = tf.stop_gradient(target_func(model,target_model,
                                                tf.constant(sample['next_obs']),
                                                tf.constant(sample["rew"].ravel()),
                                                tf.constant(sample["done"].ravel()),
                                                discount,
                                                tf.constant(env.action_space.n)))
        tf.summary.scalar("Target Q", data=tf.reduce_mean(target_Q), step=n_step)
        absTD = tf.math.abs(target_Q - Q)

        # Sub-sample according to surrogate priorities
        #   When loss is L2 or Huber, and no activation at the last layer,
        #   |TD| is surrogate priority.
        sample = LaBER(priorities=absTD)
        indexes = tf.constant(sample["indexes"])
        weights = tf.constant(sample["weights"])

        absTD = tf.gather(absTD, indexes)
        assert absTD.shape == weights.shape, f"BUG: absTD.shape: {absTD.shae}, weights.shape {weights.shape}"

        loss = tf.reduce_mean(loss_func(absTD)*weights)

    grad = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grad, model.trainable_weights))
    tf.summary.scalar("Loss vs training step", data=loss, step=n_step)


    if done:
        env.reset()
        rb.on_episode_end()
        n_episode += 1

    if n_step % target_update_freq == 0:
        target_model.set_weights(model.get_weights())

    if n_step % eval_freq == eval_freq-1:
        eval_rew = evaluate(model,eval_env)
        tf.summary.scalar("episode reward vs training step",data=eval_rew,step=n_step)

4. おわりに

面白そうな手法LaBERの論文を読んで、自作のReplay Bufferライブラリcpprbに実装をし、簡単にですが紹介記事を書きました。

(本当は自分で実験して確認までしたかったのですが、記事の下書きが1ヶ月以上そのままでいつまでも完成しそうになかったので、Advent Calenderを機に、ここまででの公開としました。)

興味を持ってくれた人がいれば、cpprbを使ってもらえると私が喜びます。

https://qiita.com/ymd_h/items/505c607c40cf3e42d080
https://qiita.com/ymd_h/items/ac9e3f1315d56a1b2718
https://zenn.dev/ymd_h/articles/61029e7a32542b
脚注
  1. T. Schaul et al., "Prioritized Experience Replay", ICLR (2016) (arXiv) ↩︎

  2. L. Wang et al., "Accelerating Deep Neural Network Training with Inconsistent Stochastic Gradient Descent", Neural Networks (2017) Vol.93, 219-229 (arXiv) ↩︎

Discussion

ログインするとコメントできます