Double DQNで深層強化学習
team411のSugiyamaです!
team411は電通大の技術系サークルで、プロダクト開発、ソリューション開発などを行っております。
1年生は研修、2〜3年生はマネジメント、4年〜M2年は研究知識の応用を目標にしています!
今回は深層強化学習の1つ、Double DQNについて実装したので解説したいとおもいます。
Double DQN
Double DQNの概要
元論文:Deep Reinforcement Learning with Double Q-learning
Double DQNとは従来のDQN(Deep Q Network)に
- 行動決定をQ Network
- Q関数の評価にQ TargetNetwork
という2つのネットワークを導入することで、高い精度を実現したものです。
Q ネットワークの目標値(教師データ)を
と更新しています!
ここでつかわれる文字について
-
: メインネットワークの重み\theta -
: ターゲットネットワークの重み\theta^{-} -
: Q関数(ネットワークの重みをQ(s_t, a; \theta) とするとき、状態\theta における、行動s_t の行動価値(累積報酬))a -
: ステップR_t における即時報酬t -
: 割引率\gamma - 割引率: 即時報酬か、累積報酬かどちらを重視するか
(1 < \gamma < 0)
- 割引率: 即時報酬か、累積報酬かどちらを重視するか
全体のステップについて
1, Actionの選択
Agentは環境にたいして行うActionを選択します。
探索をおこないつつ、より価値観数が高い行動を選択します。
epsilon-greedy方策の場合、
アクションはランダム、すなわち
アクションはMain Q NetworkのQvaluesが最大のもの、すなわち
が選択されます。
2, 次状態の取得
環境に対してActionを行うと、Agentは次状態
3, リプレイバッファへの記録
状態
この記録をいくつかサンプリングし、過去の複数データでQ関数を評価してあげることで、学習の効率をあげます。
3, Qネットワークの更新
3-1, リプレイバッファからのサンプリング
Batch sizeの数だけ、リプレイバッファから状態
3-2, 目標値(教師データ)の作成
-
まず次状態における、もっともQ値が高い行動を取得します
\bm{a_{main, next}}={\rm argmax}_a\bm{Q}(\bm{s_{next}}, a; \theta)
ポイント: ここではQメインネットワークを利用します -
先程取得したActionで、次状態における、ターゲットネットワークでのQ値を取得します
\bm{Q}(\bm{s_{next}}, \bm{a_{main, next}}; \theta^{-})
ポイント: ここではQターゲットネットワークを利用します -
このQ値をつかって、目標値(教師データ)を作成します
\bm{Q^*}(\bm{s}, \bm{a}) =\bm{r} + \gamma \bm{Q}(\bm{s_{next}}, \bm{a_{main, next}}; \theta^{-})
ポイント: は、現在の状態\bm{Q^*}(\bm{s}, \bm{a}) 、その時した行動\bm{s} に対する目標値です\bm{a} -
この目標値から、損失をもとめます
{Loss}=|Q^*(s_t, a_t) - Q(s_t, a_t; \theta)|^2
ポイント: 現在の状態 、その時の行動s のQメインネットワークのQ値との差分をとりますa -
この損失で、Qメインネットワークを学習させます
-
一定間隔でターゲットネットワークの重みを、メインネットワークに更新します
\theta^- \leftarrow \theta
Tensorflowによる実装
工夫点: 教師データの学習による前処理
損失の注意
損失に注意してください。
この損失はリプレイバッファから取り出した状態
リプレイバッファから取り出した行動が
ですが、
となります。
テンソル演算による回避
先述した損失の場合、シンプルにfit関数をつかって学習をおこなうわけには行かず、学習に関する処理をオーバーライドする必要があります。
またバッチごとにtrain_on_batchをつかってループで回すと、重くなる原因となります。
これは面倒なので、テンソル演算で回避します。
要するに、
-
に該当しない目標値は、メインネットワークのものとおなじ(損失0)a -
に該当する目標値はa (損失Q^* )|Q^*(s, a) - Q(s, a; \theta)|^2
ソースコード
リプレイバッファ
class ExperienceBuffer:
def __init__(self, memory_maxlen, state_shape=(state_size, ), action_shape=(action_size,)):
self.memory_maxlen = memory_maxlen
self.state_shape = state_shape
self.action_shape = action_shape
self.states = tf.Variable(tf.zeros((memory_maxlen,) + state_shape), trainable=False)
self.actions = tf.Variable(tf.zeros((memory_maxlen,) + action_shape), trainable=False)
self.rewards = tf.Variable(tf.zeros((memory_maxlen,) + action_shape), trainable=False)
self.next_states = tf.Variable(tf.zeros((memory_maxlen,) + state_shape), trainable=False)
self.index = tf.Variable(0, trainable=False)
def length(self):
return tf.minimum(self.index, self.memory_maxlen)
def add(self, state, action, reward, next_state):
idx = self.index % self.memory_maxlen
self.states[idx].assign(state)
action_one_hot = tf.one_hot(action, action_size)
self.actions[idx].assign(action_one_hot)
reward_one_hot = action_one_hot * reward
self.rewards[idx].assign(reward_one_hot)
self.next_states[idx].assign(next_state)
self.index.assign_add(1)
def sample(self, batch_size):
indices = tf.random.uniform((batch_size,), minval=0, maxval=self.index, dtype=tf.int32)
return (tf.gather(self.states, indices),
tf.gather(self.actions, indices),
tf.gather(self.rewards, indices),
tf.gather(self.next_states, indices))
エージェント
class DoubleDQN():
def __init__(self,
state_size,
action_size,
learning_rate,
gamma,
memory_maxlen,
batch_size,
epsilon,
epsilon_decay,
epsilon_min,
update_target_frequency,
node
):
# Initialize
self.state_size = state_size
self.action_size = action_size
self.batch_size = batch_size
self.step = 0 # 学習回数
self.node = node
self.history = 0
# 方策に関するハイパーパラメタ
self.epsilon = epsilon
self.epsilon_decay = epsilon_decay
self.action_size = action_size
self.epsilon_min = epsilon_min
# 学習に関するハイパーパラメタ
self.learning_rate = learning_rate
self.gamma = gamma
# リプレイメモリ
self.replay_memory = ExperienceBuffer(memory_maxlen)
# ネットワークの構築
self.main_network : tf.keras.Model = self._build_network()
self.target_network : tf.keras.Model = self._build_network()
# target_networkの重みは、mainネットワークの重みで初期化
self.target_network.set_weights(self.main_network.get_weights())
# ターゲットネットの更新頻度
self.update_target_frequency = update_target_frequency
# 送信後はこの関数を実行して、結果を学習する
def observe(self, state, action, reward, next_state):
# 1, 観測データをリプレイメモリに追加
self.replay_memory.add(state, action, reward, next_state)
# 2, 観測データで学習
# Info: リプレイメモリが十分なとき(バッチサイズより大きい)のとき、学習する
if self.replay_memory.length() > self.batch_size:
self._train_agent()
# 送信前に、この関数を実行して、行動を取得する
def get_action(self, next_state):
# ステップの更新
self.step += 1
# リプレイメモリが不十分なときは、ランダムに行動を選択
if self.replay_memory.length() < self.batch_size:
return np.random.randint(self.action_size)
# ---Policy---
# Info: ここでは、e-greedy policyを実装
if self.epsilon > np.random.rand():
# ランダムに行動を選択
action = np.random.randint(self.action_size)
else:
# メインネットワークから行動価値を計算
qvalues = self.main_network.predict(next_state)[0]
add_qinfo(self.node, self.step, qvalues, self.history.history['loss'][-1])
action = np.argmax(qvalues)
# epsilonの更新式
if self.epsilon > self.epsilon_min:
self.epsilon = (1 / (self.step / self.epsilon_decay + 1))
# ------------
return action
# 学習を行う関数
def _train_agent(self):
# バッチの取得
states, actions, rewards, next_states = self.replay_memory.sample(self.batch_size)
# Info: これはバッチ全てに対して行動価値を計算している
# 1, メインネットワークからQ値を取得
main_q_values = self.main_network(states)
# 2, メインネットワークに次状態を入力して、目標値(target)の計算で利用するの行動を取得
main_next_q_values = self.main_network(next_states)
main_next_actions = tf.one_hot(tf.argmax(main_next_q_values, axis=1), depth=self.action_size)
# 3, ターゲットネットワークに次状態を入力して、目標値(target)の計算で利用するQ値を取得
target_next_q_values = self.target_network(next_states)
# 4, ターゲットネットワークから教師データを作成する
# Creating a range for the row indices
# 4-1, 目標値の計算
targets = tf.multiply(1 - actions, main_q_values) + rewards + self.gamma * (actions * tf.reshape(tf.norm(main_next_actions * target_next_q_values, axis=1), [-1, 1]))
# loss = main_q_values - targets
# print("DEBUG: Qmain(s_t+1, a): ", main_next_q_values[0])
# print("DEBUG: argmax a Qmain(s_t+1, a): ", main_next_actions[0])
# print("DEBUG: Qtar(s_t+1, a): ", target_next_q_values[0])
# print("DEBUG: Qtar(s_t+1, argmax a): ", (main_next_actions * target_next_q_values)[0])
# print("DEBUG: reward s_t: ", rewards[0])
# print("DEBUG: formated Qtar(s_t+1, argmax a): ", (actions * tf.reshape(tf.norm(main_next_actions * target_next_q_values, axis=1), [-1, 1]))[0])
# print("DEBUG: Q(s_t, a): ", main_q_values[0])
# print("DEBUG: tar: ", targets[0])
# print("DEBUG: loss: ", loss[0])
# 5, targetを教師データとして、メインネットワークを学習
# label: targets
# output: main_q_values
self.history = self.main_network.fit(
states,
targets,
batch_size=self.batch_size,
epochs=1,
verbose=0,
)
# 5, ターゲットネットワークの重みは、定期的に更新
if self.step % self.update_target_frequency == 0:
self.target_network.set_weights(self.main_network.get_weights())
# メインネットワークのビルド
def _build_network(self):
input_layer = tf.keras.Input(shape=(self.state_size,))
dense1 = tf.keras.layers.Dense(32, activation='relu')(input_layer)
dense2 = tf.keras.layers.Dense(32, activation='relu')(dense1)
output_layer = tf.keras.layers.Dense(self.action_size, activation='softmax')(dense2)
model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate))
model.summary()
return model
使い方
def DQN_test():
env = TekitoEnv()
env.reset()
state = np.reshape(np.random.uniform(-1, 1, state_size), [-1, 1])
for episode in range(episode_count):
action = dqn_agent.get_action(state)
next_state, reward, _, _ = env.step(action)
next_state = np.reshape(next_state, [-1, 1])
dqn_agent.observe(state, action, reward, next_state)
state = next_state
まとめ
実装に自身がないので、間違いがあったらコメントで指摘をよろしくお願いします♥
参考
- どこからみてもメンダコ 『DQNの進化史 ②Double-DQN, Dueling-network, Noisy-network』https://horomary.hatenablog.com/entry/2021/02/06/013412
- @ysk0832 『Double DQN と Fixed Target Q-Network を 4 step で理解する』https://qiita.com/ysk0832/items/1c7508cd43daa95af005
- tcom, Qiita 『【深層強化学習】Double Deep Q Network(DDQN)』 https://www.tcom242242.net/entry/ai-2/強化学習/【深層強化学習】double-q-network/
- Chris Yoon, Towards Data Science『Double Deep Q Networks』https://towardsdatascience.com/double-deep-q-networks-905dd8325412
- Hado van Hasselt, Arthur Guez, David Silver『Deep Reinforcement Learning with Double Q-learning』 https://arxiv.org/abs/1509.06461
Discussion