🤖

Double DQNで深層強化学習

2023/12/07に公開

team411のSugiyamaです!
team411は電通大の技術系サークルで、プロダクト開発、ソリューション開発などを行っております。
1年生は研修、2〜3年生はマネジメント、4年〜M2年は研究知識の応用を目標にしています!

今回は深層強化学習の1つ、Double DQNについて実装したので解説したいとおもいます。

Double DQN

Double DQNの概要

元論文:Deep Reinforcement Learning with Double Q-learning
https://arxiv.org/abs/1509.06461

Double DQNとは従来のDQN(Deep Q Network)に

  • 行動決定をQ Network
  • Q関数の評価にQ TargetNetwork
    という2つのネットワークを導入することで、高い精度を実現したものです。

Q ネットワークの目標値(教師データ)を
Y^{\rm DoubleDQN}_t=R_{t+1}+\gamma Q(S_{t+1}, {\rm argmax}_a Q(S_{t+1}, a; \theta_t); \theta^{-})
と更新しています!

ここでつかわれる文字について

  • \theta: メインネットワークの重み
  • \theta^{-}: ターゲットネットワークの重み
  • Q(s_t, a; \theta): Q関数(ネットワークの重みを\thetaとするとき、状態s_tにおける、行動aの行動価値(累積報酬))
  • R_t: ステップtにおける即時報酬
  • \gamma: 割引率
    • 割引率: 即時報酬か、累積報酬かどちらを重視するか(1 < \gamma < 0)

全体のステップについて

1, Actionの選択

Agentは環境にたいして行うActionを選択します。
探索をおこないつつ、より価値観数が高い行動を選択します。

epsilon-greedy方策の場合、

\epsilon > {\rm random}()のとき、
アクションはランダム、すなわち
a_t = {\rm random({\rm actionSize})}

\epsilon < {\rm random}()のとき、
アクションはMain Q NetworkのQvaluesが最大のもの、すなわち
a_t = {\rm argmax}_a Q(s_t, a ; \theta)
が選択されます。

2, 次状態の取得

環境に対してActionを行うと、Agentは次状態s_{t+1}と、その行動の即時報酬r_tを得ます。

3, リプレイバッファへの記録

状態s_t、Actiona_t、次状態s_{t+1}、即時報酬r_tの情報は、リプレイバッファに記録します。
この記録をいくつかサンプリングし、過去の複数データでQ関数を評価してあげることで、学習の効率をあげます。

3, Qネットワークの更新

3-1, リプレイバッファからのサンプリング

Batch sizeの数だけ、リプレイバッファから状態\bm{s}、Action\bm{a}、次状態\bm{s_{next}}、即時報酬\bm{r}を取得します。

3-2, 目標値(教師データ)の作成

  1. まず次状態における、もっともQ値が高い行動を取得します
    \bm{a_{main, next}}={\rm argmax}_a\bm{Q}(\bm{s_{next}}, a; \theta)
    ポイント: ここではQメインネットワークを利用します

  2. 先程取得したActionで、次状態における、ターゲットネットワークでのQ値を取得します
    \bm{Q}(\bm{s_{next}}, \bm{a_{main, next}}; \theta^{-})
    ポイント: ここではQターゲットネットワークを利用します

  3. このQ値をつかって、目標値(教師データ)を作成します
    \bm{Q^*}(\bm{s}, \bm{a}) =\bm{r} + \gamma \bm{Q}(\bm{s_{next}}, \bm{a_{main, next}}; \theta^{-})
    ポイント: \bm{Q^*}(\bm{s}, \bm{a})は、現在の状態\bm{s}、その時した行動\bm{a}に対する目標値です

  4. この目標値から、損失をもとめます
    {Loss}=|Q^*(s_t, a_t) - Q(s_t, a_t; \theta)|^2
    ポイント: 現在の状態s、その時の行動aのQメインネットワークのQ値との差分をとります

  5. この損失で、Qメインネットワークを学習させます

  6. 一定間隔でターゲットネットワークの重みを、メインネットワークに更新します
    \theta^- \leftarrow \theta

Tensorflowによる実装

工夫点: 教師データの学習による前処理

損失の注意

損失に注意してください。
この損失はリプレイバッファから取り出した状態sにおける行動aQ^*から、メインネットワークを学習させます。
リプレイバッファから取り出した行動がa^{n=0}のとき、
{Loss^{n=0}}=|Q^*(s_t, a^{n=0}) - Q(s_t, a^{n=0}; \theta)|^2
ですが、
{Loss^{n\ne0}}=0
となります。

テンソル演算による回避

先述した損失の場合、シンプルにfit関数をつかって学習をおこなうわけには行かず、学習に関する処理をオーバーライドする必要があります。
またバッチごとにtrain_on_batchをつかってループで回すと、重くなる原因となります。
これは面倒なので、テンソル演算で回避します。
\bm{targets} = (\bm{1} - \bm{a})\bm{Q}(\bm{s}, \bm{a};\theta) + \bm{r} + \gamma \cdot \bm{a} \bm{Q^*}(\bm{s}, \bm{a})
要するに、

  • aに該当しない目標値は、メインネットワークのものとおなじ(損失0)
  • aに該当する目標値はQ^*(損失|Q^*(s, a) - Q(s, a; \theta)|^2)

ソースコード

リプレイバッファ

class ExperienceBuffer:
    def __init__(self, memory_maxlen, state_shape=(state_size, ), action_shape=(action_size,)):
        self.memory_maxlen = memory_maxlen
        self.state_shape = state_shape
        self.action_shape = action_shape
        self.states = tf.Variable(tf.zeros((memory_maxlen,) + state_shape), trainable=False)
        self.actions = tf.Variable(tf.zeros((memory_maxlen,) + action_shape), trainable=False)
        self.rewards = tf.Variable(tf.zeros((memory_maxlen,) + action_shape), trainable=False)
        self.next_states = tf.Variable(tf.zeros((memory_maxlen,) + state_shape), trainable=False)
        self.index = tf.Variable(0, trainable=False)
        
    def length(self):
        return tf.minimum(self.index, self.memory_maxlen)

    def add(self, state, action, reward, next_state):
        idx = self.index % self.memory_maxlen
        
        self.states[idx].assign(state)
        
        action_one_hot = tf.one_hot(action, action_size)
        self.actions[idx].assign(action_one_hot)
        
        reward_one_hot = action_one_hot * reward
        self.rewards[idx].assign(reward_one_hot)
        
        self.next_states[idx].assign(next_state)
        
        self.index.assign_add(1)

    def sample(self, batch_size):
        indices = tf.random.uniform((batch_size,), minval=0, maxval=self.index, dtype=tf.int32)
        return (tf.gather(self.states, indices),
                tf.gather(self.actions, indices),
                tf.gather(self.rewards, indices),
                tf.gather(self.next_states, indices))

エージェント

class DoubleDQN():
    def __init__(self,
                state_size,
                action_size,
                learning_rate,
                gamma,
                memory_maxlen,
                batch_size,
                epsilon,
                epsilon_decay,
                epsilon_min,
                update_target_frequency,
                node
                ):
        # Initialize
        self.state_size = state_size
        self.action_size = action_size
        self.batch_size = batch_size
        self.step = 0   # 学習回数
        self.node = node
        self.history = 0
        
        # 方策に関するハイパーパラメタ
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.action_size = action_size
        self.epsilon_min = epsilon_min
        
        # 学習に関するハイパーパラメタ
        self.learning_rate = learning_rate
        self.gamma = gamma
        
        # リプレイメモリ
        self.replay_memory = ExperienceBuffer(memory_maxlen)
        
        # ネットワークの構築
        self.main_network : tf.keras.Model = self._build_network()
        self.target_network : tf.keras.Model = self._build_network()
        # target_networkの重みは、mainネットワークの重みで初期化
        self.target_network.set_weights(self.main_network.get_weights())
        # ターゲットネットの更新頻度
        self.update_target_frequency = update_target_frequency
        
        
    # 送信後はこの関数を実行して、結果を学習する
    def observe(self, state, action, reward, next_state):
        # 1, 観測データをリプレイメモリに追加
        self.replay_memory.add(state, action, reward, next_state)
        # 2, 観測データで学習
        # Info: リプレイメモリが十分なとき(バッチサイズより大きい)のとき、学習する
        if self.replay_memory.length() > self.batch_size:
            self._train_agent()
    
    # 送信前に、この関数を実行して、行動を取得する
    def get_action(self, next_state):
        # ステップの更新
        self.step += 1
        
        # リプレイメモリが不十分なときは、ランダムに行動を選択
        if self.replay_memory.length() < self.batch_size:
            return np.random.randint(self.action_size)

        # ---Policy---
        # Info: ここでは、e-greedy policyを実装
        if self.epsilon > np.random.rand():
            # ランダムに行動を選択
            action = np.random.randint(self.action_size)
        else:
            # メインネットワークから行動価値を計算
            qvalues = self.main_network.predict(next_state)[0]
            add_qinfo(self.node, self.step, qvalues, self.history.history['loss'][-1])
            
            action = np.argmax(qvalues)
        
        # epsilonの更新式
        if self.epsilon > self.epsilon_min:
            self.epsilon = (1 / (self.step / self.epsilon_decay + 1))
        # ------------
        
        return action

    # 学習を行う関数
    def _train_agent(self):
        # バッチの取得
        states, actions, rewards, next_states = self.replay_memory.sample(self.batch_size)
        
        # Info: これはバッチ全てに対して行動価値を計算している
        # 1, メインネットワークからQ値を取得
        main_q_values = self.main_network(states)
        
        # 2, メインネットワークに次状態を入力して、目標値(target)の計算で利用するの行動を取得
        main_next_q_values = self.main_network(next_states)
        main_next_actions = tf.one_hot(tf.argmax(main_next_q_values, axis=1), depth=self.action_size)
        
        # 3, ターゲットネットワークに次状態を入力して、目標値(target)の計算で利用するQ値を取得
        target_next_q_values = self.target_network(next_states)

        # 4, ターゲットネットワークから教師データを作成する
        # Creating a range for the row indices
        # 4-1, 目標値の計算
        targets =   tf.multiply(1 - actions, main_q_values) + rewards + self.gamma * (actions * tf.reshape(tf.norm(main_next_actions * target_next_q_values, axis=1), [-1, 1]))
        # loss = main_q_values - targets
        
        # print("DEBUG: Qmain(s_t+1, a): ", main_next_q_values[0])
        # print("DEBUG: argmax a Qmain(s_t+1, a): ", main_next_actions[0])
        
        # print("DEBUG: Qtar(s_t+1, a): ", target_next_q_values[0])
        # print("DEBUG: Qtar(s_t+1, argmax a): ", (main_next_actions * target_next_q_values)[0])
        
        # print("DEBUG: reward s_t: ", rewards[0])
        # print("DEBUG: formated Qtar(s_t+1, argmax a): ", (actions * tf.reshape(tf.norm(main_next_actions * target_next_q_values, axis=1), [-1, 1]))[0])
        # print("DEBUG: Q(s_t, a): ", main_q_values[0])
        # print("DEBUG: tar: ", targets[0])
        # print("DEBUG: loss: ", loss[0])
            
        # 5, targetを教師データとして、メインネットワークを学習
        # label: targets
        # output: main_q_values
        self.history = self.main_network.fit(
            states,
            targets,
            batch_size=self.batch_size,
            epochs=1,
            verbose=0,
        )
        
        # 5, ターゲットネットワークの重みは、定期的に更新
        if self.step % self.update_target_frequency == 0:
            self.target_network.set_weights(self.main_network.get_weights())
        
    # メインネットワークのビルド
    def _build_network(self):
        input_layer = tf.keras.Input(shape=(self.state_size,))
        dense1 = tf.keras.layers.Dense(32, activation='relu')(input_layer)
        dense2 = tf.keras.layers.Dense(32, activation='relu')(dense1)
        output_layer = tf.keras.layers.Dense(self.action_size, activation='softmax')(dense2)
        
        model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
        
        model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate))
        
        model.summary()
        return model

使い方

def DQN_test():
    env = TekitoEnv()
    env.reset()
    state = np.reshape(np.random.uniform(-1, 1, state_size), [-1, 1])
    for episode in range(episode_count):
        
        action = dqn_agent.get_action(state)
        
        next_state, reward, _, _ = env.step(action)
        next_state = np.reshape(next_state, [-1, 1])
        
        dqn_agent.observe(state, action, reward, next_state)
        
        state = next_state

まとめ

実装に自身がないので、間違いがあったらコメントで指摘をよろしくお願いします♥

参考

Discussion