🤖

強化学習をPytorchで実装 DQN(Deep Q Network)編

2024/11/29に公開

趣旨

DQN(Deep Q Network)のアルゴリズム説明・コード実装をします.
アルゴリズム(コードもですが)に関しては,他の方がわかりやすく説明しているため,ざっくりと.
実装に関しては,CartPole(カートポール)問題に取り組みます.

強化学習

強化学習では,エージェント(学習者)が環境の中で行動し,その結果に応じて報酬を受け取ります.エージェントは受け取る報酬の総和が最大となるような方策(最適方策)を獲得できるように学習していきます.

迷路の出口を目指して進むエージェントについて考えてみます.まず,エージェントは迷路という環境(environment)から状態(state)を受け取ります.ここでの状態は,エージェント自身の位置です.次に,エージェントは状態からどのような行動(action)を取るか決定します.行動はエージェントの方策により決定されます.そして,その行動に対して報酬(reward)を環境から受け取ります.この報酬からエージェントは,取った行動が良かったか悪かったかを学習します.その後もエージェントは,何度も迷路に挑戦しながら,どの行動がゴールまでの最短ルートにつながるのかを学習していきます.

Q-Learning(Q学習)

Q学習は,強化学習と呼ばれる機械学習の1種です.
Q学習は,エージェントが各状態でどの行動を取るべきかをQ値(行動価値)で学習していく手法です.Q値は,現在の状態に対してある行動を取った場合の価値で,簡単に言うと,状態と行動の組み合わせの点数といったイメージです.Q値はQテーブルに格納されています.テーブルのサイズは,(状態の数)×(状態1つのパターン数)×(行動の数)で,この後に登場するCartPole問題では状態が,カートの位置,速度,ポールの角度,角速度4つです.ここで,何もしなければ各状態の連続値を状態として取得することになります(例:0.1, 0.4, 0.2, 0.3).しかし,これでは状態1つのパターン数が無限に存在することになります.この問題を解決するためによく採られる方法は,連続値を離散化することです.行動は,左か右の2択であるため,行動数は2です.各状態を4つに離散化すれば,状態1つのパターン数が4となり,テーブルサイズは,4×4×2=32です.

そして,下の数式がQ値の更新式です.

Q'(S_t, A_t) = Q(S_t, A_t) + α[R_t + γ\max_{a} Q(S_{t+1}, a) - Q(S_t, A_t)]

S_t : ステップtでの状態
A_t : ステップtでの行動
R_t : ステップtでの報酬
Q(S_t, A_t) : 状態S_t,行動A_tでのQ値
α :学習率
γ :割引率

行動を取った後に,期待される報酬と現在のQ値の差分(R_t + γ\max_{a} Q(S_{t+1}, a) - Q(S_t, A_t)TD誤差と呼ぶ)を使って現在のQ(s, a)を調整しています.\max_a Q(S_{t+1}, a)は状態S_{t+1}で最適な行動を取ったときのQ値(最適な行動とは,Q値が最大となる行動).また,

学習率が0に近いほど,Q値の更新幅が小さく,1に近いほど,更新幅が大きいです.割引率が0に近いほど,次の状態での行動価値をあまり考慮せず,1に近いほど,考慮します.

\max_a Q(S_{t+1}, a)の箇所を推定方策,ターゲット方策と呼んだりします.これをQ(S_{t+1}, A_{t+1})としたものがSARSAと呼ばれる手法です.

ε-greedy法

実際,どのように行動を選択すればよいかについて述べていきます.

学習初期は,どの行動が良いかわからないので,いろいろな行動を試す必要があります.学習が十分でないのにも関わらず,Q値が最大となる行動だけを取っていると,それ以外の行動に対してのQ値が一向に更新されないです.ある程度Q値が定まってきたら,Q値が高い行動を多めに取っていきます.Q値が低い行動を何度取っても意味がないからです.これを手法として確立したものがε-greedy法になります.

エージェントの方策が\piの場合,挙動方策\pi'はε-greedy法により,

\pi'(a|S_t) = \begin{cases} \mathrm{argmax}_a Q_π(S_t, a) &\text{if } \mathrm{1-\varepsilon} \\\\ \text{random action} &\text{if } \mathrm{\varepsilon} \end{cases}

1-εの確率でQ値が最大の行動を取り,εの確率でランダムな行動を取る.これで探索利用の両立が可能になる.

DQN(Deep Q Network)

Q学習では,すべての状態と行動の組み合わせをQテーブルに格納します.したがって,状態数と行動数が膨大であった場合,必要なメモリ空間が非常に大きくなります.また,すべての組み合わせを探索することは現実的に不可能です.

ここでDQN(Deep Q Network)の登場です.DQNではQテーブルを使用せず,Q関数をニューラルネットワークで近似します(ニューラルネットワークの説明は割愛).ニューラルネットワークに状態を入力した後,その状態でのすべての行動のQ値を計算します(下の画像では,画像→畳み込み層→全結合層→Q値の流れ).

更新式はQ学習と同じです(※実際にこれを使用するわけではない).

Q'(S_t, A_t) = Q(S_t, A_t) + α[R_t + γ\max_{a} Q(S_{t+1}, a) - Q(S_t, A_t)]

ただ,ニューラルネットワークを使用するからには,損失関数が必要になります.
Q学習にも同じことが言えますが,学習前のQ値と学習後のQ値が同じであれば,狙い通りの学習ができていると言えます(DQNの場合は不可能のため,あくまで理想).すなわち,R_t + γ\max_{a} Q(S_{t+1}, a) - Q(S_t, A_t)が小さければ良いということです.これをそのまま損失関数に使用します.

Loss = R_t + γ\max_{a} Q(S_{t+1}, a) - Q(S_t, A_t)

しかし,ただQテーブルをニューラルネットワークに置き換えて,損失関数を与えただけでうまくいくほど甘くありません.加えて以下の技術がDQNで採用されています.

Experience Replay(経験再生)

DQNでは,エージェントが経験したデータ(状態,行動,報酬,次の状態)をメモリに保存し,学習時にそのメモリからランダムにデータを取り出して学習します.これを経験再生と言います.

Q学習では1回の経験をすぐに学習に使用し,この経験がその後の学習で使用されることはありません.しかし,それだとデータに偏りが出たり,古いデータがすぐに忘れられたりします.得た経験を1度の学習で使い捨てるのはもったいないです(サンプル効率が向上するとか言うことが多いです).

少し話は逸れますが,強化学習のアルゴリズムには,オンポリシーオフポリシーの(多分)2種類があります.オフポリシーは挙動方策と推定(ターゲット)方策が異なり,オンポリシーは同じです.Q学習は,挙動方策がQ_π(S_t, A_t),推定方策が\max_{a} Q_π(S_{t+1}, a)で,オフポリシーだとわかります.したがって,Q学習の1種であるDQNもオフポリシーであり,挙動方策と推定方策が異なっていても大丈夫です.以上より,DQNは過去の方策によって得られた経験を現在の方策更新にも使用でき,経験再生の適用が可能だと言えます.

Target Network(ターゲットネットワーク)

DQNではQ値を更新する際に(推定方策として)ターゲットネットワークと呼ばれる固定された別のネットワークを使用します.挙動方策に使用されるネットワークが定期的にターゲットネットワークにコピーされます.学習の度にコロコロと推定方策が変わると,学習が安定しないです.

下の数式が,ターゲットネットワークを新たに導入した損失関数.

Loss = (R_t + γ\max_{a} Q_{TargetNet}(S_{t+1}, a) - Q_{Net}(S_t, A_t)) ^ 2

Q_{TargetNet}: ターゲットネットワークで出力されるQ値
Q_{Net}: メインネットワークで出力されるQ値

Pytorchで実装

CartPole問題

CartPole問題は、台車の上に垂直に立てられた棒を倒さないように、台車を左右に動かしてバランスを取る制御問題です。強化学習や制御理論の基礎として使われ、エージェントは適切な動作を学び、ポールを長時間立て続けることを目指します。

行動空間
・0:カートを左に押す
・1:カートを右に押す

観測空間

No. 観測 最小値 最大値
0 カートの位置 -4.8 4.8
1 カートの速度 -Inf Inf
2 ポールの角度 -0.418 rad (-24°) 0.418 rad (24°)
3 ポールの角速度 -Inf Inf

・この表はあくまでも観測であり,カートの位置が(-2.4, 2.4)もしくは,ポールの位置(-0.2095, 2.095)の範囲を超えるとゲームオーバーとなる.

報酬
・終了ステップを含む各ステップで+1の報酬
・500ステップ目まで報酬が与えられる(500ステップまでポールを維持したら終了)

レコーディング-2024-09-29-221108.gif

ハイパーパラメータ

hyperparameters.py
TOTAL_TIMESTEPS = 50000 # 総ステップ数
MAX_STEP = 500 # 1エピソードでの最大ステップ数
BUFFER_SIZE = 1000000 #バッファサイズ
BATCH_SIZE = 32 # バッチサイズ
LEARNING_RATE = 0.0001 # 学習率
DISCOUNT_RATE = 0.99 # 割引率
STATE_SIZE = 4 # 状態数
ACTION_SIZE = 2 # 行動数
TARGET_UPDATE_STEPS = 1000 # ターゲットネットワークの更新ステップ頻度
LOG_STEPS = 5000 # ログ出力のステップ頻度

経験再生バッファ

ExperienceReplayBuffer.py
from typing import NamedTuple, Tuple, Union
from collections import deque
import copy
import random
import numpy as np
import torch

class TorchTensor(NamedTuple):
    state: torch.Tensor
    action: torch.Tensor
    reward: torch.Tensor
    next_state: torch.Tensor
    done: torch.Tensor
    
class ExperienceReplayBuffer:
    def __init__(
        self,
        buffer_size: int=10000,
        batch_size: int=64,
    ):
        self.buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
    def add(
        self,
        state: Union[np.ndarray, Tuple],
        action: int,
        reward: Union[int, float],
        next_state: Union[np.ndarray, Tuple],
        done: bool,
        ) -> None:
        self.buffer.append((state, action, reward, next_state, done))
        
    def get(self) -> TorchTensor:
        data = random.sample(self.buffer, self.batch_size)
        
        batch_data = (
            np.stack([x[0] for x in data]).astype(np.float32), # state
            np.array([x[1] for x in data]).astype(np.int32), # action
            np.array([x[2] for x in data]).astype(np.float32), # reward
            np.stack([x[3] for x in data]).astype(np.float32), # next_state
            np.array([x[4] for x in data]).astype(np.int32), # done
        )
        
        return TorchTensor(*tuple(map(self.to_torch, batch_data)))
        
    def to_torch(self, array: np.ndarray) -> torch.Tensor:
        return torch.tensor(array, dtype=torch.float32, device=self.device)   
コード説明

TorchTensorは,typing.NamedTupleを継承したクラスです.この後のAgentクラスで状態,行動,報酬等をドットアクセスで取得可能となり,可読性の向上が見込めます.また,すべてのコードに共通していることとして,型アノテーションが付いていますが,必ずその型の値を入力とする必要はないです.

example_01.py
replay = ExperienceReplayBuffer(state_size, action_size)
data = replay.get()
state = data.state # 状態をドットアクセスで取得

ExperienceReplayBufferクラスには,データの追加をするaddメソッド,データを取り出すgetメソッド,配列をnp.ndarrayからtorch.Tensorに変換するto_torchメソッドがあります.

addメソッドでは,エージェントが取得した(状態,行動,報酬,次の状態,終了条件を満たしたか)をバッファに追加します.今回はバッファにcollections.dequeを使用し,引数でmaxlenを指定すると,dequeの最大長を制限できます.すでにdequeが満杯の状態でデータを追加すると,先頭の要素が捨てられることになります.

getメソッドでは,まずバッファからバッチサイズ分のデータをランダムで取り出します.次に,各要素をnumpy配列に変換してバッチデータとしますが,状態と次の状態は要素が1つではないためnp.stackを使用して新たな軸で結合しています.しかし,このままニューラルネットワークに入力した場合,torch.Tensor型ではなくnp.ndarray型であるため,エラーを吐きます.したがって,作成したバッチデータをto_torchメソッドを使用することで変換しています.

:::details

ニューラルネットワーク

Q_Network.py
from torch import nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, state_size: int, action_size: int):
        super(Net, self).__init__()
        self.layer1 = nn.Linear(state_size, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, action_size)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = self.layer3(x)
        
        return x
コード説明

これは,pytorchを使用した一般的な深層学習のコードです.今回は,難易度が低いCartPole問題のため,中間層は1つです.すべて全結合層で,活性化関数はReluを使用しています.中間層を追加するなり,ユニット数を変えるなりしてみてください.

エージェント

Agent.py
from torch import optim

class Agent:
    def __init__(self, state_size: int, action_size: int):
        self.state_size = state_size
        self.action_size = action_size

        self.lr = LEARNING_RATE
        self.gamma = DISCOUNT_RATE
        self.buffer_size = BUFFER_SIZE
        self.batch_size = BATCH_SIZE
        self.target_update = TARGET_UPDATE_STEPS

        self.epsilon_start = 1.0
        self.epsilon_end = 0.1
        self.epsilon_decay = (self.epsilon_start - self.epsilon_end) / TOTAL_TIMESTEPS
        self.epsilon = self.epsilon_start

        self.replay = ExperienceReplayBuffer(self.buffer_size, self.batch_size)
        self.data = None

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.original_qnet = Net(self.state_size, self.action_size).to(self.device)
        self.target_qnet = Net(self.state_size, self.action_size).to(self.device)
        self.sync_net()

        self.optimizer = optim.Adam(self.original_qnet.parameters(), self.lr)
        
    def get_action(self, state) -> int:
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        else:
            state = torch.tensor(state[np.newaxis, :].astype(np.float32), device=self.device)
            q_c = self.original_qnet(state)
            return q_c.detach().argmax().item()
            
    def update(self) -> None:
        if len(self.replay.buffer) < self.batch_size:
            return
            
        self.data = self.replay.get()
        q_c = self.original_qnet(self.data.state)
        q = q_c[np.arange(self.batch_size), self.data.action.cpu().numpy()]

        with torch.no_grad()
            next_q_c = self.target_qnet(self.data.next_state)
            next_q = next_q_c.max(1)[0]
            next_q.detach()
            target = self.data.reward + (1 - self.data.done) * self.gamma *  next_q

        loss_function = nn.MSELoss()
        loss = loss_function(q, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
    def add_experience(
        self,
        state: np.ndarray,
        action: int,
        reward: Union[int, float],
        next_state: np.ndarray,
        done: bool,
        ) -> None:
        self.replay.add(state, action, reward, next_state, done)
        
    def sync_net(self) -> None:
        self.target_qnet = copy.deepcopy(self.original_qnet)
        
    def set_epsilon(self) -> None:
        self.epsilon -= self.epsilon_decay
        
    def save_model(self) -> None:
        torch.save(self.original_qnet.state_dict(), 'model.pth')
コード説明

Agentクラスには,行動を取得するget_actionメソッド,ネットワークを更新するupdateメソッド,経験を追加するadd_experienceメソッド,挙動方策に使用されるネットワークをターゲットネットワークにコピーするsync_netメソッド,ε-greedy法のεを更新するset_epsilonメソッド,モデルを保存するsave_modelメソッドがあります.

initメソッドで,ハイパーパラメータや2つのネットワークの初期化を行います.他には,ExperienceReplayBufferを呼び出したり,最適化手法を設定したりします.torch.optim.Adamの引数がself.original_qnet.parameters()だけである理由は,self.target_qnetはただのコピーだからです.

get_actionメソッドでは,エージェントの状態から行動を決定します.ε-greedy法を使用しているため,1-εの確率でQ値が最大となる行動,εの確率でランダムな行動を取ります.Q値が最大となる行動を取得する場合は,状態を二次元配列にしてからtorch.Tensorに変換.q_c.detach()では,Q値をネットワークから切り離している(DQNの場合,選択した行動を直接評価するわけではないため).

updateメソッドでは,バッチデータから損失関数を計算しネットワークを更新します.バッファサイズがバッチサイズより小さい場合は更新しません.まず,バッチデータをバッファから取得します.その後は,self.original_qnetでQ_{Net}(S_t, a)を取得し,その中からQ_{Net}(S_t, A_t)のみを取り出す.次は,self.target_qnetでQ_{TargetNet}(S_{t+1}, a)を取得し,\max_{a} Q_{TargetNet}(S_{t+1}, a)を取り出す.ここでもnext_q.detach()が出現していますが,先ほど述べた通り,target_qnetはあくまでコピーであり,このネットワークを直接更新することはないため,ネットワークから切り離します.出揃ったので,R_t + γ\max_{a} Q_{TargetNet}(S_{t+1}, a)を算出してから,Q_{Net}(S_t, A_t)との平均二乗誤差を算出する.pytorchはtensorflowとは異なり,勾配が蓄積されるためself.optimizer.zero_grad()で勾配を初期化します(RNNなどの時系列系との兼ね合いのためだった気が).loss.backward()で逆伝播,self.optimizer.step()で実際に更新しています.

下の実行関数で直接ExperienceReplayBuffer().add()としてもよいが,agentのメソッドだけで進行した方がわかりやすいため,Agentクラスにadd_experienceメソッドを作成しました.

sync_netメソッドは,ただcopy.deepcopyしているだけです.

set_epsilonメソッドは,εを更新しているだけです.

save_modelメソッドは,学習したモデルを保存しているだけです.

実行関数

main.py
import gymnasium as gym
from logging import getLogger, basicConfig, INFO

basicConfig(level=INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = getLogger(__name__)

def main():
    env = gym.make('CartPole-v1')
    agent = Agent(STATE_SIZE, ACTION_SIZE)
    timesteps = 0
    while timesteps < TOTAL_TIMESTEPS:
        done = False
        state, _ = env.reset()
        steps = 0
        while steps < MAX_STEP:
            action = agent.get_action(state)
            next_state, reward, done, _, _ = env.step(action)
            agent.add_experience(state, action, reward, next_state, done)
            agent.update()
            state = next_state

            timesteps += 1
            steps += 1

            agent.set_epsilon()

            if done:
                break
            if timesteps % agent.target_update == 0:
                agent.sync_net()
            
            if timesteps % LOG_STEPS == 0:
                logger.info("time_steps: %d", timesteps)
    agent.save_model()
    env.close()
    env = gym.make('CartPole-v1', render_mode="human")
    done = False
    state, _  = env.reset()
    step = 0
    agent.epsilon = 0.0
    while step < MAX_STEP:
        env.render()
        action = agent.get_action(state)
        next_state, reward, done, _, _ = env.step(action)
        state = next_state

        if done:
            break
        
            
if __name__ == '__main__':
    main()
コード説明

最初に学習し,次に評価という流れです.

まとめ

DQNは深層強化学習で最も有名な手法であるため,他の強化学習アルゴリズムを使用するにしてもDQN理解は必須だなと感じます.

DQNを発展させたアルゴリズムについてもこのコードに少々変更を加えれば実装可能であるため,余裕があれば取り組んでみたいです.

参考文献

  1. Volodymyr Mnih, et al., “Human-level control through deep reinforcement learning”, Nature, Vol. 518, pp. 529-533, 2015.

Discussion