🤖

OpenAI Gymを使った強化学習の応用へ 〜パート3 マス目の世界を探索する〜

に公開

*この記事はQiita記事の再投稿となります。

こんにちは!
株式会社アイディオットでデータサイエンティストをしています、秋田と申します。
このシリーズは、強化学習のフレームワークを用いた最適化問題への応用を目的に、強化学習についてPythonライブラリの使用方法の観点から学ぼうというものになります。
前回はGymを使って、Grid Worldの環境を作成しました。
今回は、前回のGrid Worldを探索してゴールに向かうエージェントを作成してみましょう!

復習

前回作ったGrid World環境の GridWorldEnv クラスについて復習しましょう。
まずはコード全体を確認してみます。

# ライブラリのインポート
import gym
from gym import spaces
from matplotlib import colors
import matplotlib.pyplot as plt
import numpy as np
class GridWorldEnv(gym.Env):
    def __init__(self):
        # クラス継承のおまじない
        super(GridWorldEnv, self).__init__()

        # Action Spaceの定義
        self.action_space = spaces.Discrete(4)
        self.action_desc = """
        移動方向

        0: 上に進む
        1: 下に進む
        2: 左に進む
        3: 右に進む
        """

        # Observation Spaceの定義
        self.observation_space = spaces.Box(
            low=np.array([-10000, -10000, -10000, -10000]),
            high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
        )
        self.observation_desc = """
        報酬の種類

        黒(穴): -10000
        濃い紫: -3
        青: -2
        緑青: -1
        黄: 0
        明るい緑: 1
        緑: 2
        濃い緑: 3
        """

        # 壁と報酬の色分け
        self.cmap = colors.ListedColormap(
            [
                "#000000", "#440154", "#3b528b", "#21918c",
                "#fde725", "#aadc32", "#5ec962", "#2fb47c"
            ]
        )
        self.bounds = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
        self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)

    def reset(self) -> None:
        # Grid Worldの定義
        self.grid = np.array([
            [
                0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
                2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
            ],
            [
                0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
                1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
            ],
            [
                0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
                1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
            ],
            [
                1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
                -1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
            ],
            [
                2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
                1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
            ],
            [
                2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
                0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
            ],
            [
                1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
                0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
            ],
            [
                0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
                1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
            ],
            [
                0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
                0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
            ],
            [
                1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
                2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
            ],
            [
                1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
                1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
            ],
            [
                2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
                -3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
            ],
            [
                1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
                1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
            ],
            [
                0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
                2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
            ],
            [
                0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
                3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
            ],
            [
                1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
                0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
            ],
            [
                2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
                1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
            ],
            [
                1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
                0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
            ],
            [
                0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
                0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
            ],
            [
                0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
                2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
            ]
        ], dtype=int)

        # 行数と列数を取得
        self.rows, self.cols = self.grid.shape

        # 現在地(スタート地点)を取得
        self.current_pos = [0, 0]

        # 報酬を0にする
        self.total = 0

    def render(self) -> None:
        # 描画
        fig, ax = plt.subplots(figsize=(15, 10), tight_layout=True)
        im = ax.imshow(self.grid, cmap=self.cmap, norm=self.norm)

        # 現在地
        ax.add_patch(
            plt.Rectangle(
                (self.current_pos[0] - 0.5, self.current_pos[1] - 0.5), 1, 1,
                edgecolor="red", linewidth=2, fill=False
            )
        )
        ax.text(
            self.current_pos[0], self.current_pos[1], "P",
            va="center", ha="center", fontsize=15, color="red", weight="bold"
        )

        # スタート地点:左上
        if (self.current_pos[0] != 0) or (self.current_pos[1] != 0):
            ax.add_patch(
                plt.Rectangle(
                    (0 - 0.5, 0 - 0.5), 1, 1, edgecolor="white",
                    linewidth=2, fill=False
                )
            )
            ax.text(
                0, 0, "S", va="center", ha="center",
                fontsize=15, color="white", weight="bold"
            )

        # ゴール地点:右下
        if (self.current_pos[0] != self.cols - 1) or\
            (self.current_pos[1] != self.rows - 1):
            ax.add_patch(
                plt.Rectangle(
                    (self.cols - 1 - 0.5, self.rows - 1 - 0.5), 1, 1,
                    edgecolor="orange", linewidth=2, fill=False
                )
            )
            ax.text(
                self.cols - 1, self.rows - 1, "G", va="center", ha="center",
                fontsize=15, color="orange", weight="bold"
            )

        # 軸オフ、レイアウト調整
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(
            "Grid World (Start: Top-Left, Goal: Bottom-Right)", fontsize=16
        )
        plt.show()

    def step(self, action: int) -> tuple[np.int64, bool, dict]:
        # 終了判定をFalseに
        done = False

        # 追加情報
        info = {}

        # 現在地をコピーしてactionを元に遷移
        tmp = self.current_pos.copy()
        if action == 0:
            if self.current_pos[1] - 1 >= 0:
                self.current_pos[1] -= 1
            else:
                info["error"] = "wall"
                done = True
        elif action == 1:
            if self.current_pos[1] + 1 <= self.rows - 1:
                self.current_pos[1] += 1
            else:
                info["error"] = "wall"
                done = True
        elif action == 2:
            if self.current_pos[0] - 1 >= 0:
                self.current_pos[0] -= 1
            else:
                info["error"] = "wall"
                done = True
        else:
            if self.current_pos[0] + 1 <= self.cols - 1:
                self.current_pos[0] += 1
            else:
                info["error"] = "wall"
                done = True

        # 現在の報酬
        reward = self.grid[self.current_pos[1]][self.current_pos[0]]

        # 報酬が-10000であれば強制終了、そうでなければ合計報酬に加算
        if reward == -10000:
            if info != {}:
                info["error"] = "hole"
            done = True
        else:
            self.total += reward

        # 遷移前の地点を穴にする
        self.grid[tmp[1]][tmp[0]] = -10000

        # ゴールに着いたら終了判定をTrueに
        if self.current_pos == [self.cols - 1, self.rows - 1]:
            info["clear"] = self.total
            done = True

        return reward, done, info

インスタンスを生成して初期化し、 .step() メソッドで順に動くことで報酬を得ながら先に進むことが出来るようになりました。
また、 .render() メソッドを使うことで可視化を行うことも出来るようになりました。

# 環境のインスタンスを用意
env = GridWorldEnv()

# 初期化
env.reset()

# ゴールまでのルートを用意
routes = [
    1, 1, 1, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3, 1, 3, 1, 2, 2, 1, 1, 1, 1, 3, 1,
    1, 1, 3, 1, 2, 2, 2, 2, 2, 1, 1, 3, 3, 3, 1, 3, 3, 3, 1, 1, 1, 3, 0, 0, 0,
    0, 3, 0, 3, 1, 3, 3, 3, 3, 1, 3, 3, 3, 1, 2, 2, 1, 1, 3, 0, 3, 1, 3, 0, 0,
    3, 1, 1, 1, 3, 3, 3, 3, 3
]

# 報酬と終了判定と追加情報を取得
for route in routes:
    reward, done, info = env.step(route)
print(f"報酬:\t{reward}")
print(f"終了判定:\t{done}")
print(f"追加情報:\t{info}")

# 描画
env.render()

ここまでが復習になります!

行動の選択方法について

現状の環境では、行動を起こすことで状態を変化させることは出来ます。
一方で、行動選択の意思決定をするための条件が全くありません。
では、その条件とは一体何でしょうか?
まずは今の状況をイメージ出来るように整理してみましょう。
行動選択について前回やったこととしては、主に2種類あります。
1つ目が、何も情報のない真っ暗な2次元空間を、根拠になるものも無しに突き進む方法です。
これは、言ってしまえば毎回0から3の4つの整数値をランダムに選んで進むということになります。
2つ目は、神のお告げを聞く方法です。
神、すなわち我々Grid Worldの全貌を知るものが、予めルートを指定してその順番に行動を選択させることですね。
ただ、これを読んでくださっている方々は強化学習を勉強しに来ている人々なので、それらの方法が適切ではないことくらい承知かと思われます。
ではここに、サードオプションとして、周囲の報酬の情報を確認して進むという方法を考えてみましょう。
次の行動によって貰える報酬がわかる、つまり現在地からの四方向の報酬がそれぞれ何かがわかるのであればどちらに進むべきかの意思決定が出来そうですね。
では、 GridWorldEnv クラスに四方位の報酬を取得する .observe() メソッドというものを追加実装してみましょう!

class GridWorldEnv(gym.Env):
    def __init__(self):
        # クラス継承のおまじない
        super(GridWorldEnv, self).__init__()

        # Action Spaceの定義
        self.action_space = spaces.Discrete(4)
        self.action_desc = """
        移動方向

        0: 上に進む
        1: 下に進む
        2: 左に進む
        3: 右に進む
        """

        # Observation Spaceの定義
        self.observation_space = spaces.Box(
            low=np.array([-10000, -10000, -10000, -10000]),
            high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
        )
        self.observation_desc = """
        報酬の種類

        黒(穴): -10000
        濃い紫: -3
        青: -2
        緑青: -1
        黄: 0
        明るい緑: 1
        緑: 2
        濃い緑: 3
        """

        # 壁と報酬の色分け
        self.cmap = colors.ListedColormap(
            [
                "#000000", "#440154", "#3b528b", "#21918c",
                "#fde725", "#aadc32", "#5ec962", "#2fb47c"
            ]
        )
        self.bounds = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
        self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)

    def reset(self) -> None:
        # Grid Worldの定義
        self.grid = np.array([
            [
                0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
                2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
            ],
            [
                0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
                1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
            ],
            [
                0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
                1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
            ],
            [
                1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
                -1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
            ],
            [
                2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
                1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
            ],
            [
                2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
                0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
            ],
            [
                1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
                0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
            ],
            [
                0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
                1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
            ],
            [
                0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
                0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
            ],
            [
                1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
                2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
            ],
            [
                1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
                1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
            ],
            [
                2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
                -3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
            ],
            [
                1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
                1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
            ],
            [
                0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
                2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
            ],
            [
                0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
                3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
            ],
            [
                1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
                0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
            ],
            [
                2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
                1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
            ],
            [
                1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
                0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
            ],
            [
                0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
                0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
            ],
            [
                0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
                2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
            ]
        ], dtype=int)

        # 行数と列数を取得
        self.rows, self.cols = self.grid.shape

        # 現在地(スタート地点)を取得
        self.current_pos = [0, 0]

        # 報酬を0にする
        self.total = 0

    def render(self) -> None:
        # 描画
        fig, ax = plt.subplots(figsize=(15, 10), tight_layout=True)
        im = ax.imshow(self.grid, cmap=self.cmap, norm=self.norm)

        # 現在地
        ax.add_patch(
            plt.Rectangle(
                (self.current_pos[0] - 0.5, self.current_pos[1] - 0.5), 1, 1,
                edgecolor="red", linewidth=2, fill=False
            )
        )
        ax.text(
            self.current_pos[0], self.current_pos[1], "P",
            va="center", ha="center", fontsize=15, color="red", weight="bold"
        )

        # スタート地点:左上
        if (self.current_pos[0] != 0) or (self.current_pos[1] != 0):
            ax.add_patch(
                plt.Rectangle(
                    (0 - 0.5, 0 - 0.5), 1, 1, edgecolor="white",
                    linewidth=2, fill=False
                )
            )
            ax.text(
                0, 0, "S", va="center", ha="center",
                fontsize=15, color="white", weight="bold"
            )

        # ゴール地点:右下
        if (self.current_pos[0] != self.cols - 1) or\
            (self.current_pos[1] != self.rows - 1):
            ax.add_patch(
                plt.Rectangle(
                    (self.cols - 1 - 0.5, self.rows - 1 - 0.5), 1, 1,
                    edgecolor="orange", linewidth=2, fill=False
                )
            )
            ax.text(
                self.cols - 1, self.rows - 1, "G", va="center", ha="center",
                fontsize=15, color="orange", weight="bold"
            )

        # 軸オフ、レイアウト調整
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(
            "Grid World (Start: Top-Left, Goal: Bottom-Right)", fontsize=16
        )
        plt.show()

    def step(self, action: int) -> tuple[np.int64, bool, dict]:
        # 終了判定をFalseに
        done = False

        # 追加情報
        info = {}

        # 現在地をコピーしてactionを元に遷移
        tmp = self.current_pos.copy()
        if action == 0:
            if self.current_pos[1] - 1 >= 0:
                self.current_pos[1] -= 1
            else:
                info["error"] = "wall"
                done = True
        elif action == 1:
            if self.current_pos[1] + 1 <= self.rows - 1:
                self.current_pos[1] += 1
            else:
                info["error"] = "wall"
                done = True
        elif action == 2:
            if self.current_pos[0] - 1 >= 0:
                self.current_pos[0] -= 1
            else:
                info["error"] = "wall"
                done = True
        else:
            if self.current_pos[0] + 1 <= self.cols - 1:
                self.current_pos[0] += 1
            else:
                info["error"] = "wall"
                done = True

        # 現在の報酬
        reward = self.grid[self.current_pos[1]][self.current_pos[0]]

        # 報酬が-10000であれば強制終了、そうでなければ合計報酬に加算
        if reward == -10000:
            if info != {}:
                info["error"] = "hole"
            done = True
        else:
            self.total += reward

        # 遷移前の地点を穴にする
        self.grid[tmp[1]][tmp[0]] = -10000

        # ゴールに着いたら終了判定をTrueに
        if self.current_pos == [self.cols - 1, self.rows - 1]:
            info["clear"] = self.total
            done = True

        return reward, done, info
    
    def observe(self) -> dict:
        # 四方位の報酬を取得
        if self.current_pos[1] - 1 >= 0:
            up = self.grid[self.current_pos[1] - 1][self.current_pos[0]]
        else:
            up = -10000
        if self.current_pos[1] + 1 <= self.rows - 1:
            down = self.grid[self.current_pos[1] + 1][self.current_pos[0]]
        else:
            down = -10000
        if self.current_pos[0] - 1 >= 0:
            left = self.grid[self.current_pos[1]][self.current_pos[0] - 1]
        else:
            left = -10000
        if self.current_pos[0] + 1 <= self.cols - 1:
            right = self.grid[self.current_pos[1]][self.current_pos[0] + 1]
        else:
            right = -10000

        # 四方位の報酬を辞書型で格納
        obs = {0: up, 1: down, 2: left, 3: right}

        return obs

.observe() メソッドは、現在地から四方向に1マス進んだときの報酬(壁にぶつかる場合は穴として-10000)をそれぞれ辞書型で持ち、それを返すものになっています。
これは .step() メソッドを使って状態を遷移させる手前に、行動を選択する上で参照するものなので、 .reset() メソッドの最後及び、 .step() メソッドの最後にて呼び出して返り値として渡す形が好ましいですね。
そうなると、前々回で __init__() の中で初期化を行なったように設計し直すのが良いでしょう。

class GridWorldEnv(gym.Env):
    def __init__(self):
        # クラス継承のおまじない
        super(GridWorldEnv, self).__init__()

        # Action Spaceの定義
        self.action_space = spaces.Discrete(4)
        self.action_desc = """
        移動方向

        0: 上に進む
        1: 下に進む
        2: 左に進む
        3: 右に進む
        """

        # Observation Spaceの定義
        self.observation_space = spaces.Box(
            low=np.array([-10000, -10000, -10000, -10000]),
            high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
        )
        self.observation_desc = """
        報酬の種類

        黒(穴): -10000
        濃い紫: -3
        青: -2
        緑青: -1
        黄: 0
        明るい緑: 1
        緑: 2
        濃い緑: 3
        """

        # 壁と報酬の色分け
        self.cmap = colors.ListedColormap(
            [
                "#000000", "#440154", "#3b528b", "#21918c",
                "#fde725", "#aadc32", "#5ec962", "#2fb47c"
            ]
        )
        self.bounds = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
        self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)

        # 初期化
        self.init_obs = self.reset()

    def reset(self) -> dict:
        # Grid Worldの定義
        self.grid = np.array([
            [
                0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
                2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
            ],
            [
                0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
                1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
            ],
            [
                0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
                1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
            ],
            [
                1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
                -1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
            ],
            [
                2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
                1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
            ],
            [
                2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
                0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
            ],
            [
                1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
                0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
            ],
            [
                0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
                1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
            ],
            [
                0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
                0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
            ],
            [
                1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
                2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
            ],
            [
                1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
                1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
            ],
            [
                2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
                -3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
            ],
            [
                1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
                1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
            ],
            [
                0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
                2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
            ],
            [
                0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
                3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
            ],
            [
                1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
                0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
            ],
            [
                2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
                1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
            ],
            [
                1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
                0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
            ],
            [
                0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
                0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
            ],
            [
                0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
                2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
            ]
        ], dtype=int)

        # 行数と列数を取得
        self.rows, self.cols = self.grid.shape

        # 現在地(スタート地点)を取得
        self.current_pos = [0, 0]

        # 報酬を0にする
        self.total = 0

        # 四方位の報酬を取得
        obs = self.observe()

        return obs

    def render(self) -> None:
        # 描画
        fig, ax = plt.subplots(figsize=(15, 10), tight_layout=True)
        im = ax.imshow(self.grid, cmap=self.cmap, norm=self.norm)

        # 現在地
        ax.add_patch(
            plt.Rectangle(
                (self.current_pos[0] - 0.5, self.current_pos[1] - 0.5), 1, 1,
                edgecolor="red", linewidth=2, fill=False
            )
        )
        ax.text(
            self.current_pos[0], self.current_pos[1], "P",
            va="center", ha="center", fontsize=15, color="red", weight="bold"
        )

        # スタート地点:左上
        if (self.current_pos[0] != 0) or (self.current_pos[1] != 0):
            ax.add_patch(
                plt.Rectangle(
                    (0 - 0.5, 0 - 0.5), 1, 1, edgecolor="white",
                    linewidth=2, fill=False
                )
            )
            ax.text(
                0, 0, "S", va="center", ha="center",
                fontsize=15, color="white", weight="bold"
            )

        # ゴール地点:右下
        if (self.current_pos[0] != self.cols - 1) or\
            (self.current_pos[1] != self.rows - 1):
            ax.add_patch(
                plt.Rectangle(
                    (self.cols - 1 - 0.5, self.rows - 1 - 0.5), 1, 1,
                    edgecolor="orange", linewidth=2, fill=False
                )
            )
            ax.text(
                self.cols - 1, self.rows - 1, "G", va="center", ha="center",
                fontsize=15, color="orange", weight="bold"
            )

        # 軸オフ、レイアウト調整
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(
            "Grid World (Start: Top-Left, Goal: Bottom-Right)", fontsize=16
        )
        plt.show()

    def step(self, action: int) -> tuple[dict, np.int64, bool, dict]:
        # 終了判定をFalseに
        done = False

        # 追加情報
        info = {}

        # 現在地をコピーしてactionを元に遷移
        tmp = self.current_pos.copy()
        if action == 0:
            if self.current_pos[1] - 1 >= 0:
                self.current_pos[1] -= 1
            else:
                info["error"] = "wall"
                done = True
        elif action == 1:
            if self.current_pos[1] + 1 <= self.rows - 1:
                self.current_pos[1] += 1
            else:
                info["error"] = "wall"
                done = True
        elif action == 2:
            if self.current_pos[0] - 1 >= 0:
                self.current_pos[0] -= 1
            else:
                info["error"] = "wall"
                done = True
        else:
            if self.current_pos[0] + 1 <= self.cols - 1:
                self.current_pos[0] += 1
            else:
                info["error"] = "wall"
                done = True

        # 現在の報酬
        reward = self.grid[self.current_pos[1]][self.current_pos[0]]

        # 報酬が-10000であれば強制終了、そうでなければ合計報酬に加算
        if reward == -10000:
            if info != {}:
                info["error"] = "hole"
            done = True
        else:
            self.total += reward

        # 遷移前の地点を穴にする
        self.grid[tmp[1]][tmp[0]] = -10000

        # ゴールに着いたら終了判定をTrueに
        if self.current_pos == [self.cols - 1, self.rows - 1]:
            info["clear"] = self.total
            done = True

        # 四方位の報酬を取得
        obs = self.observe()

        return obs, reward, done, info
    
    def observe(self) -> dict:
        # 四方位の報酬を取得
        if self.current_pos[1] - 1 >= 0:
            up = self.grid[self.current_pos[1] - 1][self.current_pos[0]]
        else:
            up = -10000
        if self.current_pos[1] + 1 <= self.rows - 1:
            down = self.grid[self.current_pos[1] + 1][self.current_pos[0]]
        else:
            down = -10000
        if self.current_pos[0] - 1 >= 0:
            left = self.grid[self.current_pos[1]][self.current_pos[0] - 1]
        else:
            left = -10000
        if self.current_pos[0] + 1 <= self.cols - 1:
            right = self.grid[self.current_pos[1]][self.current_pos[0] + 1]
        else:
            right = -10000

        # 四方位の報酬を辞書型で格納
        obs = {0: up, 1: down, 2: left, 3: right}

        return obs

はい、これで完全な GridWorldEnv クラスが出来上がりました!
試しに見てみましょう。

# 環境のインスタンスを用意
env = GridWorldEnv()

# 描画
env.render()
# 初期の周囲の報酬を確認
print(f"obs:\t{env.init_obs}")

観測量がそれぞれ確認できました!

DQN

今回の強化学習モデルとして使用するのはDQNになります。
DQNそのものについての解説はここでは省略させていただきます(参考文献)。
DQNとは、Q学習とニューラルネットワークを組み合わせた手法であり、コアとなる技術は「経験再生」と「ターゲットネットワーク」の2つになります。

経験再生

Q学習では、エージェントが環境に対して行動を行うたびにデータは生成されます。
ある時間 t において得られる E_t = (S_t, A_t, R_t, S_{t+1}) を使ってQ関数を更新します。
ここで、 S_t は時刻 t における状態、 A_t は時刻 t における行動、 R_t は時刻 t における報酬を意味し、この E_t を経験データと呼ぶことにしましょう。
経験データは時間 t が進むに従って得られますが、経験データ間には強い相関があるため、教師あり学習のようにミニバッチ化をしてデータに偏りが無いように学習できません。
これを克服するのが経験再生です!
まずはエージェントが経験したデータ E_t = (S_t, A_t, R_t, S_{t+1}) を一度「バッファ」に保存します。
そして、Q関数を更新する際はそのバッファから経験データをランダムに取り出します。
これにより、経験データ間の相関が弱まるので、偏りの少ないデータが得られるということです。
さらに、経験データを繰り返し使うことが出来るため、データ効率が良くなります。
では、経験再生の仕組みを ReplayBuffer というクラスで実装してみましょう!

経験再生のバッファは無限にデータを格納することはできないので、最大サイズを超えたら古いデータから順に削除します。
よって、 .buffer という属性を作る上で、引数に buffer_size というものを用意し、collectionsの deque() 関数を使ってバッファを定義します。
また、他に引数として、バッチサイズを指定する batch_size と、今回PyTorchを使った実装を行うため device"cpu" をデフォルト値として使います。
ここまでをまずは実装します。

# ライブラリのインポート
from collections import deque

import torch
class ReplayBuffer:
    def __init__(self, buffer_size: int, batch_size: int, device: str="cpu"):
        # バッファを定義
        self.buffer = deque(maxlen=buffer_size)

        # 学習環境のパラメータを設定
        self.batch_size = batch_size
        self.device = torch.device(device)

続いてバッファの長さを取得する .__len__() メソッドを加えましょう。

class ReplayBuffer:
    def __init__(self, buffer_size: int, batch_size: int, device: str="cpu"):
        # バッファを定義
        self.buffer = deque(maxlen=buffer_size)

        # 学習環境のパラメータを設定
        self.batch_size = batch_size
        self.device = torch.device(device)

    def __len__(self) -> int:
        # バッファの長さを取得
        length = len(self.buffer)

        return length

そしてバッファに経験データ E_t を追加する .push() メソッドを作成します。

# ライブラリのインポート
from typing import Union
class ReplayBuffer:
    def __init__(self, buffer_size: int, batch_size: int, device: str="cpu"):
        # バッファを定義
        self.buffer = deque(maxlen=buffer_size)

        # 学習環境のパラメータを設定
        self.batch_size = batch_size
        self.device = torch.device(device)

    def __len__(self) -> int:
        # バッファの長さを取得
        length = len(self.buffer)

        return length

    def push(
        self, state: torch.Tensor, action: int, reward: Union[int, np.int64],
        next_state: list, done: bool
    ) -> None:
        # バッファに経験データを追加
        self.buffer.append((state, action, reward, next_state, done))

最後に、バッファの中から経験データをサンプリングしてバッチ化する .sample() メソッドを入れます。

# ライブラリのインポート
import random
class ReplayBuffer:
    def __init__(self, buffer_size: int, batch_size: int, device: str="cpu"):
        # バッファを定義
        self.buffer = deque(maxlen=buffer_size)

        # 学習環境のパラメータを設定
        self.batch_size = batch_size
        self.device = torch.device(device)

    def __len__(self) -> int:
        # バッファの長さを取得
        length = len(self.buffer)

        return length

    def push(
        self, state: torch.Tensor, action: int, reward: Union[int, np.int64],
        next_state: list, done: bool
    ) -> None:
        # バッファに経験データを追加
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self) -> tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
    ]:
        # ランダムなサンプルを用意してバッチ化
        batch = random.sample(self.buffer, self.batch_size)

        # 種類ごとにバッチ内データをTensor型に変換
        state = torch.FloatTensor(
            np.stack([x[0].to("cpu").detach().numpy().copy() for x in batch])
        ).to(self.device)
        action = torch.LongTensor(
            np.stack([x[1] for x in batch])
        ).unsqueeze(1).to(self.device)
        reward = torch.FloatTensor(
            np.stack([x[2] for x in batch])
        ).unsqueeze(1).to(self.device)
        next_state = torch.FloatTensor(
            np.stack([x[3] for x in batch])
        ).to(self.device)
        done = torch.FloatTensor(
            np.stack([x[4] for x in batch])
        ).unsqueeze(1).to(self.device)

        return state, action, reward, next_state, done

これで経験再生のクラスが作れました!

ターゲットネットワーク

Q学習では、 Q(S_t, A_t) の値がTDターゲット R_t + \gamma \max_a Q(S_{t+1}, a) となるようにQ関数を更新します。
しかし、TDターゲットの値はQ関数が更新されるたびに変動するため、教師あり学習のようにラベルが学習途中で変わらないということがありません。
ターゲットネットワークは、これを克服するためのTDターゲットを固定するテクニックになります!
これは、次のような仕組みでできています。
まずはQ関数を表すオリジナルのネットワーク「Original」を用意します。
それとは別にもう1つ同じ構造のネットワーク「Target」を用意します。
「Original」は通常のQ学習によって更新を行い、「Target」は定期的に「Original」の重みと同期するようにして、それ以外は重みパラメータを固定したままにします。
後は「Target」を使ってTDターゲットの値を計算すれば、TDターゲットの変動が抑えられるので、ニューラルネットワークの学習が安定することが期待されます。

まずはPyTorchでベースとなるニューラルネットワークを作りましょう!
PyTorchでのニューラルネットワークの作成方法については理解がある前提で進めます。
今回は3層の線形層と活性化関数は全てReLUを使ったモデルを構築します。

# ライブラリのインポート
from torch import nn
class DQN(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        # クラス継承のおまじない
        super(DQN, self).__init__()

        # 各種ネットワークの定義
        self.linear1 = nn.Linear(input_dim, 128)
        self.linear2 = nn.Linear(128, 128)
        self.linear3 = nn.Linear(128, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 順伝播
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.relu(x)
        x = self.linear3(x)

        return x

これを「Original」と「Target」に使用できるような DQNAgent クラスを作成しましょう。
必要なハイパーパラメータの種類は次の通りです。

  • input_dim
    • モデルの入力層の次元
    • 状態(観測量)の数、ただし他の情報を入れる可能性もある
  • output_dim
    • モデルの出力層の次元
    • 状態(観測量)の数
  • lr
    • 学習率
    • デフォルト値 5e-4
  • gamma
    • 割引率
    • 連続タスクにおいて収益を無限大に発散させるのを防ぐためのパラメータ
    • デフォルト値 0.98
  • eps
    • モデルを介した意思決定をしない確率(ϵ-greedy)
    • デフォルト値 1.0
  • eps_decay
    • ϵの値を徐々に下げていく割合
    • デフォルト値 0.995
  • eps_min
    • ϵの最小値
    • デフォルト値 0.01
  • buffer_size
    • バッファサイズ
    • デフォルト値 10000
  • batch_size
    • バッチサイズ
    • デフォルト値 32
  • device
    • PyTorchの計算環境デバイス
    • デフォルト値 "cpu"
class DQNAgent:
    def __init__(
        self, input_dim: int, output_dim: int, lr: float=5e-4,
        gamma: float=0.98, eps: float=1.0, eps_decay: float=0.995,
        eps_min: float=0.01, buffer_size: int=10000, batch_size: int=32,
        device: str="cpu"
    ):
        # ハイパーパラメータ等を属性として持つ
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.device = torch.device(device)

続いて DQN クラスを使って2つのモデルを作成しましょう。

class DQNAgent:
    def __init__(
        self, input_dim: int, output_dim: int, lr: float=5e-4,
        gamma: float=0.98, eps: float=1.0, eps_decay: float=0.995,
        eps_min: float=0.01, buffer_size: int=10000, batch_size: int=32,
        device: str="cpu"
    ):
        # ハイパーパラメータ等を属性として持つ
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.device = torch.device(device)

        # 2つのモデルを構築
        self.model = DQN(self.input_dim, self.output_dim).to(self.device)
        self.target_model = DQN(self.input_dim, self.output_dim).to(self.device)

一旦 __init__() についてはここで止めておきましょう。
まだ改修が必要なので、その点は留意していただきます。

.get_action() メソッド

ここでは、 __init__() の中で作ったモデルを用いて行動の選択を行えるようにしましょう!
基本的には .modelDQN クラスインスタンスに予測計算をしてもらうのですが、最初のうちはモデルが学習されていないですし、必ずしもモデルが選択した結果が正しいとは言えないため .eps の確率で選択方法を変える必要があります。
今回は、 .eps の確率で周囲四方向の中で最も報酬の値が大きい箇所を選択するという方法を取りましょう。

class DQNAgent:
    def __init__(
        self, input_dim: int, output_dim: int, lr: float=5e-4,
        gamma: float=0.98, eps: float=1.0, eps_decay: float=0.995,
        eps_min: float=0.01, buffer_size: int=10000, batch_size: int=32,
        device: str="cpu"
    ):
        # ハイパーパラメータ等を属性として持つ
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.device = torch.device(device)

        # 2つのモデルを構築
        self.model = DQN(self.input_dim, self.output_dim).to(self.device)
        self.target_model = DQN(self.input_dim, self.output_dim).to(self.device)

    def get_action(self, state: torch.Tensor) -> int:
        # 生成された乱数がϵの値より小さい場合は報酬が最大のものを選択する
        if np.random.rand() < self.eps:
            action = torch.argmax(state).item()
        else:
            qs = self.model(state)
            action = torch.argmax(qs).item()
        
        return action

これでエージェントが行動を選択するロジックができましたね!

.update() メソッド

今度は状態ではなく、モデル内部のパラメータを更新するための .update() メソッドを実装してみましょう。
ここは、普通のニューラルネットワークによる深層学習のトレーニングの部分で行われている処理に該当します。
学習に関わってくるので、 ReplayBuffer クラスも使いそうですね。
__init__() の改修を行います。
と言っても、ただ ReplayBuffer のインスタンスを用意して学習に必要な最適化器と損失関数を属性として定義するだけなので、あまり身構えなくても大丈夫です!
最適化器にはAdamを、損失関数にはMSEを使っておきましょうか。

# ライブラリのインポート
from torch import optim
class DQNAgent:
    def __init__(
        self, input_dim: int, output_dim: int, lr: float=5e-4,
        gamma: float=0.98, eps: float=1.0, eps_decay: float=0.995,
        eps_min: float=0.01, buffer_size: int=10000, batch_size: int=32,
        device: str="cpu"
    ):
        # ハイパーパラメータ等を属性として持つ
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.device = torch.device(device)

        # 2つのモデルを構築
        self.model = DQN(self.input_dim, self.output_dim).to(self.device)
        self.target_model = DQN(self.input_dim, self.output_dim).to(self.device)

        # 経験再生のインスタンス
        self.replay_buffer = ReplayBuffer(
            self.buffer_size, self.batch_size, self.device
        )

        # 学習に必要な最適化器と損失関数を定義
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def get_action(self, state: torch.Tensor) -> int:
        # 生成された乱数がϵの値より小さい場合は報酬が最大のものを選択する
        if np.random.rand() < self.eps:
            action = torch.argmax(state).item()
        else:
            qs = self.model(state)
            action = torch.argmax(qs).item()
        
        return action

まだ __init__() が完全ではないですがまた中断し、本題の .update() メソッドを実装します。
モデルをアップデートするので、バッファの中身も変わってきます。
.push() メソッドでバッファを更新させましょう。
また、その際にバッチサイズ分のデータがそもそも無ければ学習が出来ないので処理を抜けるようにしましょう。

class DQNAgent:
    def __init__(
        self, input_dim: int, output_dim: int, lr: float=5e-4,
        gamma: float=0.98, eps: float=1.0, eps_decay: float=0.995,
        eps_min: float=0.01, buffer_size: int=10000, batch_size: int=32,
        device: str="cpu"
    ):
        # ハイパーパラメータ等を属性として持つ
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.device = torch.device(device)

        # 2つのモデルを構築
        self.model = DQN(self.input_dim, self.output_dim).to(self.device)
        self.target_model = DQN(self.input_dim, self.output_dim).to(self.device)

        # 経験再生のインスタンス
        self.replay_buffer = ReplayBuffer(
            self.buffer_size, self.batch_size, self.device
        )

        # 学習に必要な最適化器と損失関数を定義
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def get_action(self, state: torch.Tensor) -> int:
        # 生成された乱数がϵの値より小さい場合は報酬が最大のものを選択する
        if np.random.rand() < self.eps:
            action = torch.argmax(state).item()
        else:
            qs = self.model(state)
            action = torch.argmax(qs).item()
        
        return action

    def update(
        self, state: torch.Tensor, action: int, reward: Union[int, np.int64],
        next_state: list, done: bool
    ) -> None:
        # バッファにデータを追加
        self.replay_buffer.push(state, action, reward, next_state, done)

        # 学習に必要なデータ数が足りていない場合はここで抜ける
        if self.replay_buffer.__len__() < self.batch_size:
            return

続いて、学習が出来るようであればバッファの中からデータを引っ張ってきて、Q関数の出力を見てみましょう。
Q関数は、 Q(S_t, A_t) というように時刻 t における状態 S_t と行動 A_t を用いて表されます。
一気に2つの変数を注ぎ込むことは出来ないので、段階的に作っていきます。
まず、状態 S_t をオリジナルのモデルに入れることで、 qs という途中の状態を作ります。
そして、行動 A_t を次のようにして組み込みます。

q = qs[np.arange(len(action)), action]

状態 S_t をモデルに入れた出力の中で、行動 A_t に応じた成分のみ抽出するようなイメージです。
これがQ関数 Q(S_t, A_t) になります。

class DQNAgent:
    def __init__(
        self, input_dim: int, output_dim: int, lr: float=5e-4,
        gamma: float=0.98, eps: float=1.0, eps_decay: float=0.995,
        eps_min: float=0.01, buffer_size: int=10000, batch_size: int=32,
        device: str="cpu"
    ):
        # ハイパーパラメータ等を属性として持つ
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.device = torch.device(device)

        # 2つのモデルを構築
        self.model = DQN(self.input_dim, self.output_dim).to(self.device)
        self.target_model = DQN(self.input_dim, self.output_dim).to(self.device)

        # 経験再生のインスタンス
        self.replay_buffer = ReplayBuffer(
            self.buffer_size, self.batch_size, self.device
        )

        # 学習に必要な最適化器と損失関数を定義
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def get_action(self, state: torch.Tensor) -> int:
        # 生成された乱数がϵの値より小さい場合は報酬が最大のものを選択する
        if np.random.rand() < self.eps:
            action = torch.argmax(state).item()
        else:
            qs = self.model(state)
            action = torch.argmax(qs).item()
        
        return action

    def update(
        self, state: torch.Tensor, action: int, reward: Union[int, np.int64],
        next_state: list, done: bool
    ) -> None:
        # バッファにデータを追加
        self.replay_buffer.push(state, action, reward, next_state, done)

        # 学習に必要なデータ数が足りていない場合はここで抜ける
        if self.replay_buffer.__len__() < self.batch_size:
            return

        # データをサンプリング
        state, action, reward, next_state, done = self.replay_buffer.sample()

        # Q関数の出力を生成
        qs = self.model(state)
        q = qs[np.arange(len(action)), action]

さらに、TDターゲット R_t + \gamma \max_a Q(S_{t+1}, a) も表してみましょう。
これは「Target」の方のモデルで作成します。
時刻 t+1 の状態 S_{t+1} をモデルに入れた結果を next_qs とします。
そして、各行(バッチ毎)の最大値を次のように取り出します。

next_q = next_qs.max(1)[0]

後は定義式に沿って実装したものを target として持っておくだけです。

class DQNAgent:
    def __init__(
        self, input_dim: int, output_dim: int, lr: float=5e-4,
        gamma: float=0.98, eps: float=1.0, eps_decay: float=0.995,
        eps_min: float=0.01, buffer_size: int=10000, batch_size: int=32,
        device: str="cpu"
    ):
        # ハイパーパラメータ等を属性として持つ
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.device = torch.device(device)

        # 2つのモデルを構築
        self.model = DQN(self.input_dim, self.output_dim).to(self.device)
        self.target_model = DQN(self.input_dim, self.output_dim).to(self.device)

        # 経験再生のインスタンス
        self.replay_buffer = ReplayBuffer(
            self.buffer_size, self.batch_size, self.device
        )

        # 学習に必要な最適化器と損失関数を定義
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def get_action(self, state: torch.Tensor) -> int:
        # 生成された乱数がϵの値より小さい場合は報酬が最大のものを選択する
        if np.random.rand() < self.eps:
            action = torch.argmax(state).item()
        else:
            qs = self.model(state)
            action = torch.argmax(qs).item()
        
        return action

    def update(
        self, state: torch.Tensor, action: int, reward: Union[int, np.int64],
        next_state: list, done: bool
    ) -> None:
        # バッファにデータを追加
        self.replay_buffer.push(state, action, reward, next_state, done)

        # 学習に必要なデータ数が足りていない場合はここで抜ける
        if self.replay_buffer.__len__() < self.batch_size:
            return

        # データをサンプリング
        state, action, reward, next_state, done = self.replay_buffer.sample()

        # Q関数の出力を生成
        qs = self.model(state)
        q = qs[np.arange(len(action)), action]

        # TDターゲットを生成
        with torch.no_grad():
            next_qs = self.target_model(next_state)
            next_q = next_qs.max(1)[0]
            target = reward + (1 - done) * self.gamma * next_q

        # 計算できるように型を揃える
        target = torch.tensor(target, dtype=torch.float32)

最後に、ニューラルネットワークでの学習にいつもやる損失計算と誤差逆伝播です!
ここは説明が要らないですね。

class DQNAgent:
    def __init__(
        self, input_dim: int, output_dim: int, lr: float=5e-4,
        gamma: float=0.98, eps: float=1.0, eps_decay: float=0.995,
        eps_min: float=0.01, buffer_size: int=10000, batch_size: int=32,
        device: str="cpu"
    ):
        # ハイパーパラメータ等を属性として持つ
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.device = torch.device(device)

        # 2つのモデルを構築
        self.model = DQN(self.input_dim, self.output_dim).to(self.device)
        self.target_model = DQN(self.input_dim, self.output_dim).to(self.device)

        # 経験再生のインスタンス
        self.replay_buffer = ReplayBuffer(
            self.buffer_size, self.batch_size, self.device
        )

        # 学習に必要な最適化器と損失関数を定義
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def get_action(self, state: torch.Tensor) -> int:
        # 生成された乱数がϵの値より小さい場合は報酬が最大のものを選択する
        if np.random.rand() < self.eps:
            action = torch.argmax(state).item()
        else:
            qs = self.model(state)
            action = torch.argmax(qs).item()
        
        return action

    def update(
        self, state: torch.Tensor, action: int, reward: Union[int, np.int64],
        next_state: list, done: bool
    ) -> None:
        # バッファにデータを追加
        self.replay_buffer.push(state, action, reward, next_state, done)

        # 学習に必要なデータ数が足りていない場合はここで抜ける
        if self.replay_buffer.__len__() < self.batch_size:
            return

        # データをサンプリング
        state, action, reward, next_state, done = self.replay_buffer.sample()

        # Q関数の出力を生成
        qs = self.model(state)
        q = qs[np.arange(len(action)), action]

        # TDターゲットを生成
        with torch.no_grad():
            next_qs = self.target_model(next_state)
            next_q = next_qs.max(1)[0]
            target = reward + (1 - done) * self.gamma * next_q

        # 計算できるように型を揃える
        target = torch.tensor(target, dtype=torch.float32)

        # 損失計算と誤差逆伝播
        loss = self.criterion(q, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

これで学習もバッチリです!

.sync_net() メソッド

本当に最後に、「Target」を「Model」とシンクロさせるだけのメソッドを実装します。

class DQNAgent:
    def __init__(
        self, input_dim: int, output_dim: int, lr: float=5e-4,
        gamma: float=0.98, eps: float=1.0, eps_decay: float=0.995,
        eps_min: float=0.01, buffer_size: int=10000, batch_size: int=32,
        device: str="cpu"
    ):
        # ハイパーパラメータ等を属性として持つ
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.device = torch.device(device)

        # 2つのモデルを構築
        self.model = DQN(self.input_dim, self.output_dim).to(self.device)
        self.target_model = DQN(self.input_dim, self.output_dim).to(self.device)

        # 経験再生のインスタンス
        self.replay_buffer = ReplayBuffer(
            self.buffer_size, self.batch_size, self.device
        )

        # 学習に必要な最適化器と損失関数を定義
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def get_action(self, state: torch.Tensor) -> int:
        # 生成された乱数がϵの値より小さい場合は報酬が最大のものを選択する
        if np.random.rand() < self.eps:
            action = torch.argmax(state).item()
        else:
            qs = self.model(state)
            action = torch.argmax(qs).item()
        
        return action

    def update(
        self, state: torch.Tensor, action: int, reward: Union[int, np.int64],
        next_state: list, done: bool
    ) -> None:
        # バッファにデータを追加
        self.replay_buffer.push(state, action, reward, next_state, done)

        # 学習に必要なデータ数が足りていない場合はここで抜ける
        if self.replay_buffer.__len__() < self.batch_size:
            return

        # データをサンプリング
        state, action, reward, next_state, done = self.replay_buffer.sample()

        # Q関数の出力を生成
        qs = self.model(state)
        q = qs[np.arange(len(action)), action]

        # TDターゲットを生成
        with torch.no_grad():
            next_qs = self.target_model(next_state)
            next_q = next_qs.max(1)[0]
            target = reward + (1 - done) * self.gamma * next_q

        # 計算できるように型を揃える
        target = torch.tensor(target, dtype=torch.float32)

        # 損失計算と誤差逆伝播
        loss = self.criterion(q, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def sync_net(self) -> None:
        # モデルを同期させる
        self.target_model.load_state_dict(self.model.state_dict())

このメソッドは初期化した状態でも行わなければいけないですね!

class DQNAgent:
    def __init__(
        self, input_dim: int, output_dim: int, lr: float=5e-4,
        gamma: float=0.98, eps: float=1.0, eps_decay: float=0.995,
        eps_min: float=0.01, buffer_size: int=10000, batch_size: int=32,
        device: str="cpu"
    ):
        # ハイパーパラメータ等を属性として持つ
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.device = torch.device(device)

        # 2つのモデルを構築
        self.model = DQN(self.input_dim, self.output_dim).to(self.device)
        self.target_model = DQN(self.input_dim, self.output_dim).to(self.device)

        # 初期状態としてモデルを同期しておく
        self.sync_net()

        # 経験再生のインスタンス
        self.replay_buffer = ReplayBuffer(
            self.buffer_size, self.batch_size, self.device
        )

        # 学習に必要な最適化器と損失関数を定義
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def get_action(self, state: torch.Tensor) -> int:
        # 生成された乱数がϵの値より小さい場合は報酬が最大のものを選択する
        if np.random.rand() < self.eps:
            action = torch.argmax(state).item()
        else:
            qs = self.model(state)
            action = torch.argmax(qs).item()
        
        return action

    def update(
        self, state: torch.Tensor, action: int, reward: Union[int, np.int64],
        next_state: list, done: bool
    ) -> None:
        # バッファにデータを追加
        self.replay_buffer.push(state, action, reward, next_state, done)

        # 学習に必要なデータ数が足りていない場合はここで抜ける
        if self.replay_buffer.__len__() < self.batch_size:
            return

        # データをサンプリング
        state, action, reward, next_state, done = self.replay_buffer.sample()

        # Q関数の出力を生成
        qs = self.model(state)
        q = qs[np.arange(len(action)), action]

        # TDターゲットを生成
        with torch.no_grad():
            next_qs = self.target_model(next_state)
            next_q = next_qs.max(1)[0]
            target = reward + (1 - done) * self.gamma * next_q

        # 計算できるように型を揃える
        target = torch.tensor(target, dtype=torch.float32)

        # 損失計算と誤差逆伝播
        loss = self.criterion(q, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def sync_net(self) -> None:
        # モデルを同期させる
        self.target_model.load_state_dict(self.model.state_dict())

これで DQNAgent クラスの完成です!

現時点での総括

今回はエージェントの作成を行いました。
前回までとは違って難しかったポイントとしては、何をするかという理論は理解できても、可視化などでチェックを行いながら進めにくかったため、抽象的になってしまいました。
なので本当にこれで上手くいくのか不安に思うこともあるかもしれませんが、それは次回試してみましょう。
まずはここまでのコードを以下にまとめておきましょう!

# ライブラリのインポート
from collections import deque
import random
from typing import Union

import gym
from gym import spaces
from matplotlib import colors
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
# Grid Worldの環境クラス
class GridWorldEnv(gym.Env):
    def __init__(self):
        # クラス継承のおまじない
        super(GridWorldEnv, self).__init__()

        # Action Spaceの定義
        self.action_space = spaces.Discrete(4)
        self.action_desc = """
        移動方向

        0: 上に進む
        1: 下に進む
        2: 左に進む
        3: 右に進む
        """

        # Observation Spaceの定義
        self.observation_space = spaces.Box(
            low=np.array([-10000, -10000, -10000, -10000]),
            high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
        )
        self.observation_desc = """
        報酬の種類

        黒(穴): -10000
        濃い紫: -3
        青: -2
        緑青: -1
        黄: 0
        明るい緑: 1
        緑: 2
        濃い緑: 3
        """

        # 壁と報酬の色分け
        self.cmap = colors.ListedColormap(
            [
                "#000000", "#440154", "#3b528b", "#21918c",
                "#fde725", "#aadc32", "#5ec962", "#2fb47c"
            ]
        )
        self.bounds = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
        self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)

        # 初期化
        self.init_obs = self.reset()

    def reset(self) -> dict:
        # Grid Worldの定義
        self.grid = np.array([
            [
                0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
                2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
            ],
            [
                0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
                1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
            ],
            [
                0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
                1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
            ],
            [
                1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
                -1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
            ],
            [
                2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
                1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
            ],
            [
                2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
                0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
            ],
            [
                1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
                0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
            ],
            [
                0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
                1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
            ],
            [
                0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
                0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
            ],
            [
                1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
                2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
            ],
            [
                1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
                1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
            ],
            [
                2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
                -3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
            ],
            [
                1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
                1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
            ],
            [
                0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
                2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
            ],
            [
                0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
                3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
            ],
            [
                1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
                0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
            ],
            [
                2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
                1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
            ],
            [
                1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
                0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
            ],
            [
                0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
                0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
            ],
            [
                0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
                2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
            ]
        ], dtype=int)

        # 行数と列数を取得
        self.rows, self.cols = self.grid.shape

        # 現在地(スタート地点)を取得
        self.current_pos = [0, 0]

        # 報酬を0にする
        self.total = 0

        # 四方位の報酬を取得
        obs = self.observe()

        return obs

    def render(self) -> None:
        # 描画
        fig, ax = plt.subplots(figsize=(15, 10), tight_layout=True)
        im = ax.imshow(self.grid, cmap=self.cmap, norm=self.norm)

        # 現在地
        ax.add_patch(
            plt.Rectangle(
                (self.current_pos[0] - 0.5, self.current_pos[1] - 0.5), 1, 1,
                edgecolor="red", linewidth=2, fill=False
            )
        )
        ax.text(
            self.current_pos[0], self.current_pos[1], "P",
            va="center", ha="center", fontsize=15, color="red", weight="bold"
        )

        # スタート地点:左上
        if (self.current_pos[0] != 0) or (self.current_pos[1] != 0):
            ax.add_patch(
                plt.Rectangle(
                    (0 - 0.5, 0 - 0.5), 1, 1, edgecolor="white",
                    linewidth=2, fill=False
                )
            )
            ax.text(
                0, 0, "S", va="center", ha="center",
                fontsize=15, color="white", weight="bold"
            )

        # ゴール地点:右下
        if (self.current_pos[0] != self.cols - 1) or\
            (self.current_pos[1] != self.rows - 1):
            ax.add_patch(
                plt.Rectangle(
                    (self.cols - 1 - 0.5, self.rows - 1 - 0.5), 1, 1,
                    edgecolor="orange", linewidth=2, fill=False
                )
            )
            ax.text(
                self.cols - 1, self.rows - 1, "G", va="center", ha="center",
                fontsize=15, color="orange", weight="bold"
            )

        # 軸オフ、レイアウト調整
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(
            "Grid World (Start: Top-Left, Goal: Bottom-Right)", fontsize=16
        )
        plt.show()

    def step(self, action: int) -> tuple[dict, np.int64, bool, dict]:
        # 終了判定をFalseに
        done = False

        # 追加情報
        info = {}

        # 現在地をコピーしてactionを元に遷移
        tmp = self.current_pos.copy()
        if action == 0:
            if self.current_pos[1] - 1 >= 0:
                self.current_pos[1] -= 1
            else:
                info["error"] = "wall"
                done = True
        elif action == 1:
            if self.current_pos[1] + 1 <= self.rows - 1:
                self.current_pos[1] += 1
            else:
                info["error"] = "wall"
                done = True
        elif action == 2:
            if self.current_pos[0] - 1 >= 0:
                self.current_pos[0] -= 1
            else:
                info["error"] = "wall"
                done = True
        else:
            if self.current_pos[0] + 1 <= self.cols - 1:
                self.current_pos[0] += 1
            else:
                info["error"] = "wall"
                done = True

        # 現在の報酬
        reward = self.grid[self.current_pos[1]][self.current_pos[0]]

        # 報酬が-10000であれば強制終了、そうでなければ合計報酬に加算
        if reward == -10000:
            if info != {}:
                info["error"] = "hole"
            done = True
        else:
            self.total += reward

        # 遷移前の地点を穴にする
        self.grid[tmp[1]][tmp[0]] = -10000

        # ゴールに着いたら終了判定をTrueに
        if self.current_pos == [self.cols - 1, self.rows - 1]:
            info["clear"] = self.total
            done = True

        # 四方位の報酬を取得
        obs = self.observe()

        return obs, reward, done, info
    
    def observe(self) -> dict:
        # 四方位の報酬を取得
        if self.current_pos[1] - 1 >= 0:
            up = self.grid[self.current_pos[1] - 1][self.current_pos[0]]
        else:
            up = -10000
        if self.current_pos[1] + 1 <= self.rows - 1:
            down = self.grid[self.current_pos[1] + 1][self.current_pos[0]]
        else:
            down = -10000
        if self.current_pos[0] - 1 >= 0:
            left = self.grid[self.current_pos[1]][self.current_pos[0] - 1]
        else:
            left = -10000
        if self.current_pos[0] + 1 <= self.cols - 1:
            right = self.grid[self.current_pos[1]][self.current_pos[0] + 1]
        else:
            right = -10000

        # 四方位の報酬を辞書型で格納
        obs = {0: up, 1: down, 2: left, 3: right}

        return obs


# 経験再生のクラス
class ReplayBuffer:
    def __init__(self, buffer_size: int, batch_size: int, device: str="cpu"):
        # バッファを定義
        self.buffer = deque(maxlen=buffer_size)

        # 学習環境のパラメータを設定
        self.batch_size = batch_size
        self.device = torch.device(device)

    def __len__(self) -> int:
        # バッファの長さを取得
        length = len(self.buffer)

        return length

    def push(
        self, state: torch.Tensor, action: int, reward: Union[int, np.int64],
        next_state: list, done: bool
    ) -> None:
        # バッファに経験データを追加
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self) -> tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
    ]:
        # ランダムなサンプルを用意してバッチ化
        batch = random.sample(self.buffer, self.batch_size)

        # 種類ごとにバッチ内データをTensor型に変換
        state = torch.FloatTensor(
            np.stack([x[0].to("cpu").detach().numpy().copy() for x in batch])
        ).to(self.device)
        action = torch.LongTensor(
            np.stack([x[1] for x in batch])
        ).unsqueeze(1).to(self.device)
        reward = torch.FloatTensor(
            np.stack([x[2] for x in batch])
        ).unsqueeze(1).to(self.device)
        next_state = torch.FloatTensor(
            np.stack([x[3] for x in batch])
        ).to(self.device)
        done = torch.FloatTensor(
            np.stack([x[4] for x in batch])
        ).unsqueeze(1).to(self.device)

        return state, action, reward, next_state, done


# DQNモデルのクラス
class DQN(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        # クラス継承のおまじない
        super(DQN, self).__init__()

        # 各種ネットワークの定義
        self.linear1 = nn.Linear(input_dim, 128)
        self.linear2 = nn.Linear(128, 128)
        self.linear3 = nn.Linear(128, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 順伝播
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.relu(x)
        x = self.linear3(x)

        return x


# DQNのエージェントクラス
class DQNAgent:
    def __init__(
        self, input_dim: int, output_dim: int, lr: float=5e-4,
        gamma: float=0.98, eps: float=1.0, eps_decay: float=0.995,
        eps_min: float=0.01, buffer_size: int=10000, batch_size: int=32,
        device: str="cpu"
    ):
        # ハイパーパラメータ等を属性として持つ
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.device = torch.device(device)

        # 2つのモデルを構築
        self.model = DQN(self.input_dim, self.output_dim).to(self.device)
        self.target_model = DQN(self.input_dim, self.output_dim).to(self.device)

        # 初期状態としてモデルを同期しておく
        self.sync_net()

        # 経験再生のインスタンス
        self.replay_buffer = ReplayBuffer(
            self.buffer_size, self.batch_size, self.device
        )

        # 学習に必要な最適化器と損失関数を定義
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def get_action(self, state: torch.Tensor) -> int:
        # 生成された乱数がϵの値より小さい場合は報酬が最大のものを選択する
        if np.random.rand() < self.eps:
            action = torch.argmax(state).item()
        else:
            qs = self.model(state)
            action = torch.argmax(qs).item()
        
        return action

    def update(
        self, state: torch.Tensor, action: int, reward: Union[int, np.int64],
        next_state: list, done: bool
    ) -> None:
        # バッファにデータを追加
        self.replay_buffer.push(state, action, reward, next_state, done)

        # 学習に必要なデータ数が足りていない場合はここで抜ける
        if self.replay_buffer.__len__() < self.batch_size:
            return

        # データをサンプリング
        state, action, reward, next_state, done = self.replay_buffer.sample()

        # Q関数の出力を生成
        qs = self.model(state)
        q = qs[np.arange(len(action)), action]

        # TDターゲットを生成
        with torch.no_grad():
            next_qs = self.target_model(next_state)
            next_q = next_qs.max(1)[0]
            target = reward + (1 - done) * self.gamma * next_q

        # 計算できるように型を揃える
        target = torch.tensor(target, dtype=torch.float32)

        # 損失計算と誤差逆伝播
        loss = self.criterion(q, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def sync_net(self) -> None:
        # モデルを同期させる
        self.target_model.load_state_dict(self.model.state_dict())

随分と大きなプログラムになってきましたね!

次回予告

次回は、前回作ったGrid World環境で今回作ったDQNエージェントを学習させ、ゴールを目指してどこまで行けるか試してみましょう。
ようやく深層学習のフィールドにやって来ましたね!

Discussion