OpenAI Gymを使った強化学習の応用へ 〜パート2 マス目の世界を作る〜
*この記事はQiita記事の再投稿となります。
こんにちは!
株式会社アイディオットでデータサイエンティストをしています、秋田と申します。
このシリーズは、強化学習のフレームワークを用いた最適化問題への応用を目的に、強化学習についてPythonライブラリの使用方法の観点から学ぼうというものになります。
前回はGymの基本的な使い方を学び、簡単な環境を自作しました。
今回は、Grid Worldと呼ばれるマス目の世界で、上下左右の四方向に動いてスタートからゴールに向かって進み、道中で獲得できる報酬を最大化する問題の環境を作成してみましょう!
Gymの役目
前回はGymを使って環境(ジャンケン)を作成しました。
しかし、中には次のように疑問に思った人もいると思います。
「Gymを使う必要ってあった?」
実際に自作した MyEnv
クラスを見ても、 .action_space
や .observation_space
を呼び出す必要は特に無いですし、その他のメソッドを見てもGymに関連するような処理は一切行われていません。
実際グゥの音も出ないほどに(ジャンケンだけに)正論で、 gym.Env
クラスの継承をしなければいけないことはありません。
ただし、Gymは以前から強化学習フレームワークで使われており、エージェントを作成する様々な状況において慣れ親しんでいるという背景があります。
なので、このフレームワークを頭に叩き込んでおくことが重要であり、そのために本プロジェクトでは一貫してGymを使い続けます。
Grid Worldとは
(20, 30)のマス目の世界でエージェントがマスを移動し、各マスの報酬を加算しながらゴールを目指す問題を考えます。
今回の問題では、マスの種類は8つあるとしましょう。
まず、歩けるマスが次の7種類です。
- 濃い紫: -3
- 青: -2
- 緑青: -1
- 黄: 0
- 明るい緑: 1
- 緑: 2
- 濃い緑: 3
そして、他に穴があるとします。
- 黒: -10000
穴の報酬は-10000としますが、基本的にこれは報酬に加算せずに強制終了を意味します。
このGrid Worldでは、穴に落ちないように、また報酬を多く受け取れるように進みますが、一度歩いたマスは次の瞬間には穴になってしまいます。
また、外枠は壁で囲まれており、壁にぶつかるとその地点が穴になるため強制終了してしまいます。
それぞれのマスの状況を行列として定義しますが、ランダムに作ったものになるので適宜変更していただいても構いません。
__init__()
内にAction SpaceとObservation Spaceを定義する
では早速作ってみましょう!
前回同様にAction SpaceとObservation Spaceを定義します。
今回の行動は、上下左右の方向選択になります。
よって、次のようなコードを書けば良さそうですね。
action_space = spaces.Discrete(4)
また、今回は action_space
の行動についての説明を入れましょう。
それぞれの離散値は次のように対応します。
- 0: 上に進む
- 1: 下に進む
- 2: 左に進む
- 3: 右に進む
これを action_desc
として書いておきましょう。
まずはここまでを進めます。
# ライブラリのインポート
import gym
from gym import spaces
import numpy as np
class GridWorldEnv(gym.Env):
def __init__(self):
# クラス継承のおまじない
super(GridWorldEnv, self).__init__()
# Action Spaceの定義
self.action_space = spaces.Discrete(4)
self.action_desc = """
移動方向
0: 上に進む
1: 下に進む
2: 左に進む
3: 右に進む
"""
# 環境のインスタンスを用意
env = GridWorldEnv()
# Action Spaceの確認
print(env.action_space)
print(env.action_desc)
続いて、Observation Spaceを定義しましょう。
各行動に対して最低報酬が-10000、最高報酬が3となるので
observation_space = spaces.Box(
low=np.array([-10000, -10000, -10000, -10000]),
high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
)
と書けそうですね。
また、こちらも説明を入れましょう。
class GridWorldEnv(gym.Env):
def __init__(self):
# クラス継承のおまじない
super(GridWorldEnv, self).__init__()
# Action Spaceの定義
self.action_space = spaces.Discrete(4)
self.action_desc = """
移動方向
0: 上に進む
1: 下に進む
2: 左に進む
3: 右に進む
"""
# Observation Spaceの定義
self.observation_space = spaces.Box(
low=np.array([-10000, -10000, -10000, -10000]),
high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
)
self.observation_desc = """
報酬の種類
黒(穴): -10000
濃い紫: -3
青: -2
緑青: -1
黄: 0
明るい緑: 1
緑: 2
濃い緑: 3
"""
# 環境のインスタンスを用意
env = GridWorldEnv()
# Observation Spaceの確認
print(env.observation_space)
print(env.observation_desc)
良い感じですね!
.reset()
メソッドに初期状態の問題を埋め込む
まずは少し大変ですが、問題そのものをメソッドの中で定義してしまいます。
さらに、行数と列数を取得してスタート地点を決めてしまいましょう!
やっていることは難しくないと思います。
class GridWorldEnv(gym.Env):
def __init__(self):
# クラス継承のおまじない
super(GridWorldEnv, self).__init__()
# Action Spaceの定義
self.action_space = spaces.Discrete(4)
self.action_desc = """
移動方向
0: 上に進む
1: 下に進む
2: 左に進む
3: 右に進む
"""
# Observation Spaceの定義
self.observation_space = spaces.Box(
low=np.array([-10000, -10000, -10000, -10000]),
high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
)
self.observation_desc = """
報酬の種類
黒(穴): -10000
濃い紫: -3
青: -2
緑青: -1
黄: 0
明るい緑: 1
緑: 2
濃い緑: 3
"""
def reset(self) -> None:
# Grid Worldの定義
self.grid = np.array([
[
0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
],
[
0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
],
[
0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
],
[
1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
-1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
],
[
2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
],
[
1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
],
[
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
],
[
0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
],
[
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
],
[
1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
],
[
2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
-3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
],
[
1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
],
[
0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
],
[
0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
],
[
1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
],
[
1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
],
[
0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
],
[
0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
]
], dtype=int)
# 行数と列数を取得
self.rows, self.cols = self.grid.shape
# 現在地(スタート地点)を取得
self.current_pos = [0, 0]
# 環境のインスタンスを用意
env = GridWorldEnv()
# 初期化
env.reset()
# Grid Worldの初期状態の行列を出力
print(env.grid)
# Grid Worldの行数と列数を出力
print(f"行数:\t{env.rows}")
print(f"列数:\t{env.cols}")
# スタート地点(現在地)を出力
print(f"Current Position:\t{env.current_pos}")
少し長くなり見づらいかもしれませんが、問題の行列が大きすぎて嵩張っているだけなので構えなくて大丈夫です。
後は前回同様に報酬の合計を0にするように記述しましょう。
class GridWorldEnv(gym.Env):
def __init__(self):
# クラス継承のおまじない
super(GridWorldEnv, self).__init__()
# Action Spaceの定義
self.action_space = spaces.Discrete(4)
self.action_desc = """
移動方向
0: 上に進む
1: 下に進む
2: 左に進む
3: 右に進む
"""
# Observation Spaceの定義
self.observation_space = spaces.Box(
low=np.array([-10000, -10000, -10000, -10000]),
high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
)
self.observation_desc = """
報酬の種類
黒(穴): -10000
濃い紫: -3
青: -2
緑青: -1
黄: 0
明るい緑: 1
緑: 2
濃い緑: 3
"""
def reset(self) -> None:
# Grid Worldの定義
self.grid = np.array([
[
0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
],
[
0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
],
[
0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
],
[
1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
-1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
],
[
2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
],
[
1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
],
[
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
],
[
0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
],
[
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
],
[
1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
],
[
2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
-3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
],
[
1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
],
[
0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
],
[
0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
],
[
1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
],
[
1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
],
[
0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
],
[
0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
]
], dtype=int)
# 行数と列数を取得
self.rows, self.cols = self.grid.shape
# 現在地(スタート地点)を取得
self.current_pos = [0, 0]
# 報酬を0にする
self.total = 0
# 環境のインスタンスを用意
env = GridWorldEnv()
# 初期化
env.reset()
# 報酬の表示
print(env.total)
これで .reset()
メソッドは終了です!
.render()
メソッドで描画できるようにする
ここから少し難しくなってくるので、注意していてください。
前回の MyEnv
クラスでは、描画することが無かったので合計報酬を出力するのがこの .render()
メソッドの仕事でしたが、今回はGrid Worldという可視化が可能な世界を扱っているため、その名の通りレンダリングしてみましょう。
そのために、ハイパーパラメータのようなものを予め設定するのですが、これは __init__()
の中に追記しておきましょう。
# ライブラリのインポート
from matplotlib import colors
import matplotlib.pyplot as plt
class GridWorldEnv(gym.Env):
def __init__(self):
# クラス継承のおまじない
super(GridWorldEnv, self).__init__()
# Action Spaceの定義
self.action_space = spaces.Discrete(4)
self.action_desc = """
移動方向
0: 上に進む
1: 下に進む
2: 左に進む
3: 右に進む
"""
# Observation Spaceの定義
self.observation_space = spaces.Box(
low=np.array([-10000, -10000, -10000, -10000]),
high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
)
self.observation_desc = """
報酬の種類
黒(穴): -10000
濃い紫: -3
青: -2
緑青: -1
黄: 0
明るい緑: 1
緑: 2
濃い緑: 3
"""
# 壁と報酬の色分け
self.cmap = colors.ListedColormap(
[
"#000000", "#440154", "#3b528b", "#21918c",
"#fde725", "#aadc32", "#5ec962", "#2fb47c"
]
)
self.bounds = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)
def reset(self) -> None:
# Grid Worldの定義
self.grid = np.array([
[
0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
],
[
0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
],
[
0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
],
[
1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
-1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
],
[
2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
],
[
1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
],
[
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
],
[
0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
],
[
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
],
[
1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
],
[
2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
-3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
],
[
1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
],
[
0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
],
[
0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
],
[
1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
],
[
1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
],
[
0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
],
[
0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
]
], dtype=int)
# 行数と列数を取得
self.rows, self.cols = self.grid.shape
# 現在地(スタート地点)を取得
self.current_pos = [0, 0]
# 報酬を0にする
self.total = 0
追記した .cmap
, .bounds
, .norm
の3つは .render()
メソッドの内部で描画の際のパラメータとして使うだけなので、今までのようにインスタンスを作って確認ということは行いません。
このまま .render()
メソッドの内部を構成しましょう。
まずは全体の画像を描画するところまで進めます。
class GridWorldEnv(gym.Env):
def __init__(self):
# クラス継承のおまじない
super(GridWorldEnv, self).__init__()
# Action Spaceの定義
self.action_space = spaces.Discrete(4)
self.action_desc = """
移動方向
0: 上に進む
1: 下に進む
2: 左に進む
3: 右に進む
"""
# Observation Spaceの定義
self.observation_space = spaces.Box(
low=np.array([-10000, -10000, -10000, -10000]),
high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
)
self.observation_desc = """
報酬の種類
黒(穴): -10000
濃い紫: -3
青: -2
緑青: -1
黄: 0
明るい緑: 1
緑: 2
濃い緑: 3
"""
# 壁と報酬の色分け
self.cmap = colors.ListedColormap(
[
"#000000", "#440154", "#3b528b", "#21918c",
"#fde725", "#aadc32", "#5ec962", "#2fb47c"
]
)
self.bounds = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)
def reset(self) -> None:
# Grid Worldの定義
self.grid = np.array([
[
0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
],
[
0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
],
[
0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
],
[
1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
-1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
],
[
2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
],
[
1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
],
[
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
],
[
0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
],
[
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
],
[
1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
],
[
2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
-3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
],
[
1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
],
[
0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
],
[
0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
],
[
1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
],
[
1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
],
[
0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
],
[
0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
]
], dtype=int)
# 行数と列数を取得
self.rows, self.cols = self.grid.shape
# 現在地(スタート地点)を取得
self.current_pos = [0, 0]
# 報酬を0にする
self.total = 0
def render(self) -> None:
# 描画
fig, ax = plt.subplots(figsize=(15, 10), tight_layout=True)
im = ax.imshow(self.grid, cmap=self.cmap, norm=self.norm)
# 軸オフ、レイアウト調整
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(
"Grid World", fontsize=16
)
plt.show()
# 環境のインスタンスを用意
env = GridWorldEnv()
# 初期化
env.reset()
# 描画
env.render()
Grid Worldの全体像が可視化されました!
今度はスタート地点とゴール地点を分かりやすくしてみましょう。
スタート地点は .reset()
メソッドでも定めたように[0, 0]の位置であり、ゴールは一番右下としましょう。
class GridWorldEnv(gym.Env):
def __init__(self):
# クラス継承のおまじない
super(GridWorldEnv, self).__init__()
# Action Spaceの定義
self.action_space = spaces.Discrete(4)
self.action_desc = """
移動方向
0: 上に進む
1: 下に進む
2: 左に進む
3: 右に進む
"""
# Observation Spaceの定義
self.observation_space = spaces.Box(
low=np.array([-10000, -10000, -10000, -10000]),
high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
)
self.observation_desc = """
報酬の種類
黒(穴): -10000
濃い紫: -3
青: -2
緑青: -1
黄: 0
明るい緑: 1
緑: 2
濃い緑: 3
"""
# 壁と報酬の色分け
self.cmap = colors.ListedColormap(
[
"#000000", "#440154", "#3b528b", "#21918c",
"#fde725", "#aadc32", "#5ec962", "#2fb47c"
]
)
self.bounds = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)
def reset(self) -> None:
# Grid Worldの定義
self.grid = np.array([
[
0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
],
[
0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
],
[
0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
],
[
1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
-1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
],
[
2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
],
[
1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
],
[
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
],
[
0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
],
[
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
],
[
1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
],
[
2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
-3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
],
[
1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
],
[
0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
],
[
0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
],
[
1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
],
[
1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
],
[
0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
],
[
0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
]
], dtype=int)
# 行数と列数を取得
self.rows, self.cols = self.grid.shape
# 現在地(スタート地点)を取得
self.current_pos = [0, 0]
# 報酬を0にする
self.total = 0
def render(self) -> None:
# 描画
fig, ax = plt.subplots(figsize=(15, 10), tight_layout=True)
im = ax.imshow(self.grid, cmap=self.cmap, norm=self.norm)
# スタート地点:左上
ax.add_patch(
plt.Rectangle(
(0 - 0.5, 0 - 0.5), 1, 1, edgecolor="white",
linewidth=2, fill=False
)
)
ax.text(
0, 0, "S", va="center", ha="center",
fontsize=15, color="white", weight="bold"
)
# ゴール地点:右下
ax.add_patch(
plt.Rectangle(
(self.cols - 1 - 0.5, self.rows - 1 - 0.5), 1, 1,
edgecolor="orange", linewidth=2, fill=False
)
)
ax.text(
self.cols - 1, self.rows - 1, "G", va="center", ha="center",
fontsize=15, color="orange", weight="bold"
)
# 軸オフ、レイアウト調整
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(
"Grid World (Start: Top-Left, Goal: Bottom-Right)", fontsize=16
)
plt.show()
# 環境のインスタンスを用意
env = GridWorldEnv()
# 初期化
env.reset()
# 描画
env.render()
さらに現在地が確認できるように工夫してみましょう。
現在地はスタート地点やゴール地点と被る可能性があるので、被る場合は現在地を優先するように少しだけ書き換えます。
class GridWorldEnv(gym.Env):
def __init__(self):
# クラス継承のおまじない
super(GridWorldEnv, self).__init__()
# Action Spaceの定義
self.action_space = spaces.Discrete(4)
self.action_desc = """
移動方向
0: 上に進む
1: 下に進む
2: 左に進む
3: 右に進む
"""
# Observation Spaceの定義
self.observation_space = spaces.Box(
low=np.array([-10000, -10000, -10000, -10000]),
high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
)
self.observation_desc = """
報酬の種類
黒(穴): -10000
濃い紫: -3
青: -2
緑青: -1
黄: 0
明るい緑: 1
緑: 2
濃い緑: 3
"""
# 壁と報酬の色分け
self.cmap = colors.ListedColormap(
[
"#000000", "#440154", "#3b528b", "#21918c",
"#fde725", "#aadc32", "#5ec962", "#2fb47c"
]
)
self.bounds = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)
def reset(self) -> None:
# Grid Worldの定義
self.grid = np.array([
[
0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
],
[
0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
],
[
0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
],
[
1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
-1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
],
[
2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
],
[
1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
],
[
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
],
[
0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
],
[
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
],
[
1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
],
[
2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
-3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
],
[
1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
],
[
0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
],
[
0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
],
[
1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
],
[
1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
],
[
0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
],
[
0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
]
], dtype=int)
# 行数と列数を取得
self.rows, self.cols = self.grid.shape
# 現在地(スタート地点)を取得
self.current_pos = [0, 0]
# 報酬を0にする
self.total = 0
def render(self) -> None:
# 描画
fig, ax = plt.subplots(figsize=(15, 10), tight_layout=True)
im = ax.imshow(self.grid, cmap=self.cmap, norm=self.norm)
# 現在地
ax.add_patch(
plt.Rectangle(
(self.current_pos[0] - 0.5, self.current_pos[1] - 0.5), 1, 1,
edgecolor="red", linewidth=2, fill=False
)
)
ax.text(
self.current_pos[0], self.current_pos[1], "P",
va="center", ha="center", fontsize=15, color="red", weight="bold"
)
# スタート地点:左上
if (self.current_pos[0] != 0) or (self.current_pos[1] != 0):
ax.add_patch(
plt.Rectangle(
(0 - 0.5, 0 - 0.5), 1, 1, edgecolor="white",
linewidth=2, fill=False
)
)
ax.text(
0, 0, "S", va="center", ha="center",
fontsize=15, color="white", weight="bold"
)
# ゴール地点:右下
if (self.current_pos[0] != self.cols - 1) or\
(self.current_pos[1] != self.rows - 1):
ax.add_patch(
plt.Rectangle(
(self.cols - 1 - 0.5, self.rows - 1 - 0.5), 1, 1,
edgecolor="orange", linewidth=2, fill=False
)
)
ax.text(
self.cols - 1, self.rows - 1, "G", va="center", ha="center",
fontsize=15, color="orange", weight="bold"
)
# 軸オフ、レイアウト調整
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(
"Grid World (Start: Top-Left, Goal: Bottom-Right)", fontsize=16
)
plt.show()
# 環境のインスタンスを用意
env = GridWorldEnv()
# 初期化
env.reset()
# 描画
env.render()
これで .render()
メソッドは終了になります!
.step()
メソッドを作るにあたって
今回のGrid World問題が前回のジャンケンの問題と大きく違う点として、「終了」というものがあることが挙げられます。
一般的な強化学習の環境においても、「終了」という概念が存在し、それによって有限の世界で最適化できると考えられます。
Gymの環境では、 done
というBooleanのフラグを使うことで終了したかどうかを判定します。
基本的には done == False
としますが、終了のタイミングで True
にします。
それに伴い、強制終了やゴールをしたときの情報を得るために info
という辞書型の変数を用意しておきます。
これは基本的には空で大丈夫ですが、何かあったときには情報を追加します。
ただ、この info
に関しては割と自由度が高く、好きなように定めてOKです。
class GridWorldEnv(gym.Env):
def __init__(self):
# クラス継承のおまじない
super(GridWorldEnv, self).__init__()
# Action Spaceの定義
self.action_space = spaces.Discrete(4)
self.action_desc = """
移動方向
0: 上に進む
1: 下に進む
2: 左に進む
3: 右に進む
"""
# Observation Spaceの定義
self.observation_space = spaces.Box(
low=np.array([-10000, -10000, -10000, -10000]),
high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
)
self.observation_desc = """
報酬の種類
黒(穴): -10000
濃い紫: -3
青: -2
緑青: -1
黄: 0
明るい緑: 1
緑: 2
濃い緑: 3
"""
# 壁と報酬の色分け
self.cmap = colors.ListedColormap(
[
"#000000", "#440154", "#3b528b", "#21918c",
"#fde725", "#aadc32", "#5ec962", "#2fb47c"
]
)
self.bounds = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)
def reset(self) -> None:
# Grid Worldの定義
self.grid = np.array([
[
0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
],
[
0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
],
[
0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
],
[
1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
-1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
],
[
2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
],
[
1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
],
[
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
],
[
0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
],
[
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
],
[
1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
],
[
2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
-3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
],
[
1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
],
[
0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
],
[
0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
],
[
1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
],
[
1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
],
[
0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
],
[
0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
]
], dtype=int)
# 行数と列数を取得
self.rows, self.cols = self.grid.shape
# 現在地(スタート地点)を取得
self.current_pos = [0, 0]
# 報酬を0にする
self.total = 0
def render(self) -> None:
# 描画
fig, ax = plt.subplots(figsize=(15, 10), tight_layout=True)
im = ax.imshow(self.grid, cmap=self.cmap, norm=self.norm)
# 現在地
ax.add_patch(
plt.Rectangle(
(self.current_pos[0] - 0.5, self.current_pos[1] - 0.5), 1, 1,
edgecolor="red", linewidth=2, fill=False
)
)
ax.text(
self.current_pos[0], self.current_pos[1], "P",
va="center", ha="center", fontsize=15, color="red", weight="bold"
)
# スタート地点:左上
if (self.current_pos[0] != 0) or (self.current_pos[1] != 0):
ax.add_patch(
plt.Rectangle(
(0 - 0.5, 0 - 0.5), 1, 1, edgecolor="white",
linewidth=2, fill=False
)
)
ax.text(
0, 0, "S", va="center", ha="center",
fontsize=15, color="white", weight="bold"
)
# ゴール地点:右下
if (self.current_pos[0] != self.cols - 1) or\
(self.current_pos[1] != self.rows - 1):
ax.add_patch(
plt.Rectangle(
(self.cols - 1 - 0.5, self.rows - 1 - 0.5), 1, 1,
edgecolor="orange", linewidth=2, fill=False
)
)
ax.text(
self.cols - 1, self.rows - 1, "G", va="center", ha="center",
fontsize=15, color="orange", weight="bold"
)
# 軸オフ、レイアウト調整
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(
"Grid World (Start: Top-Left, Goal: Bottom-Right)", fontsize=16
)
plt.show()
def step(self) -> tuple[bool, dict]:
# 終了判定をFalseに
done = False
# 追加情報
info = {}
return done, info
# 環境のインスタンスを用意
env = GridWorldEnv()
# 初期化
env.reset()
# 終了判定と追加情報を取得
done, info = env.step()
print(f"終了判定:\t{done}")
print(f"追加情報:\t{info}")
今度は引数に行動を入れましょう。
ジャンケン環境では相手の手も入れていましたが、今回は行動以外に環境を変える要因は無いため action
のみを入れていきます。
処理の内容としては、まずは元いた地点をコピーし(後で穴に変える必要があるため)、四方向の動きごとに条件分岐させます。
分岐した中での処理は、壁にぶつからなければその方向に1マス進み、壁にぶつかれば強制終了なので done
を True
にし、 info
に何が起きたのかを記載します。
1マス進んだ先では報酬があるはずなので、それを受け取ります。
class GridWorldEnv(gym.Env):
def __init__(self):
# クラス継承のおまじない
super(GridWorldEnv, self).__init__()
# Action Spaceの定義
self.action_space = spaces.Discrete(4)
self.action_desc = """
移動方向
0: 上に進む
1: 下に進む
2: 左に進む
3: 右に進む
"""
# Observation Spaceの定義
self.observation_space = spaces.Box(
low=np.array([-10000, -10000, -10000, -10000]),
high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
)
self.observation_desc = """
報酬の種類
黒(穴): -10000
濃い紫: -3
青: -2
緑青: -1
黄: 0
明るい緑: 1
緑: 2
濃い緑: 3
"""
# 壁と報酬の色分け
self.cmap = colors.ListedColormap(
[
"#000000", "#440154", "#3b528b", "#21918c",
"#fde725", "#aadc32", "#5ec962", "#2fb47c"
]
)
self.bounds = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)
def reset(self) -> None:
# Grid Worldの定義
self.grid = np.array([
[
0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
],
[
0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
],
[
0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
],
[
1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
-1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
],
[
2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
],
[
1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
],
[
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
],
[
0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
],
[
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
],
[
1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
],
[
2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
-3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
],
[
1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
],
[
0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
],
[
0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
],
[
1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
],
[
1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
],
[
0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
],
[
0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
]
], dtype=int)
# 行数と列数を取得
self.rows, self.cols = self.grid.shape
# 現在地(スタート地点)を取得
self.current_pos = [0, 0]
# 報酬を0にする
self.total = 0
def render(self) -> None:
# 描画
fig, ax = plt.subplots(figsize=(15, 10), tight_layout=True)
im = ax.imshow(self.grid, cmap=self.cmap, norm=self.norm)
# 現在地
ax.add_patch(
plt.Rectangle(
(self.current_pos[0] - 0.5, self.current_pos[1] - 0.5), 1, 1,
edgecolor="red", linewidth=2, fill=False
)
)
ax.text(
self.current_pos[0], self.current_pos[1], "P",
va="center", ha="center", fontsize=15, color="red", weight="bold"
)
# スタート地点:左上
if (self.current_pos[0] != 0) or (self.current_pos[1] != 0):
ax.add_patch(
plt.Rectangle(
(0 - 0.5, 0 - 0.5), 1, 1, edgecolor="white",
linewidth=2, fill=False
)
)
ax.text(
0, 0, "S", va="center", ha="center",
fontsize=15, color="white", weight="bold"
)
# ゴール地点:右下
if (self.current_pos[0] != self.cols - 1) or\
(self.current_pos[1] != self.rows - 1):
ax.add_patch(
plt.Rectangle(
(self.cols - 1 - 0.5, self.rows - 1 - 0.5), 1, 1,
edgecolor="orange", linewidth=2, fill=False
)
)
ax.text(
self.cols - 1, self.rows - 1, "G", va="center", ha="center",
fontsize=15, color="orange", weight="bold"
)
# 軸オフ、レイアウト調整
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(
"Grid World (Start: Top-Left, Goal: Bottom-Right)", fontsize=16
)
plt.show()
def step(self, action: int) -> tuple[np.int64, bool, dict]:
# 終了判定をFalseに
done = False
# 追加情報
info = {}
# 現在地をコピーしてactionを元に遷移
tmp = self.current_pos.copy()
if action == 0:
if self.current_pos[1] - 1 >= 0:
self.current_pos[1] -= 1
else:
info["error"] = "wall"
done = True
elif action == 1:
if self.current_pos[1] + 1 <= self.rows - 1:
self.current_pos[1] += 1
else:
info["error"] = "wall"
done = True
elif action == 2:
if self.current_pos[0] - 1 >= 0:
self.current_pos[0] -= 1
else:
info["error"] = "wall"
done = True
else:
if self.current_pos[0] + 1 <= self.cols - 1:
self.current_pos[0] += 1
else:
info["error"] = "wall"
done = True
# 現在の報酬
reward = self.grid[self.current_pos[1]][self.current_pos[0]]
return reward, done, info
# 環境のインスタンスを用意
env = GridWorldEnv()
# 初期化
env.reset()
# 報酬と終了判定と追加情報を取得
reward, done, info = env.step(1) # 下に進む
print(f"報酬:\t{reward}")
print(f"終了判定:\t{done}")
print(f"追加情報:\t{info}")
# 描画
env.render()
# 環境のインスタンスを用意
env = GridWorldEnv()
# 初期化
env.reset()
# 報酬と終了判定と追加情報を取得
reward, done, info = env.step(2) # 左に進む
print(f"報酬:\t{reward}")
print(f"終了判定:\t{done}")
print(f"追加情報:\t{info}")
# 描画
env.render()
# 環境のインスタンスを用意
env = GridWorldEnv()
# 初期化
env.reset()
# 報酬と終了判定と追加情報を取得
reward, done, info = env.step(1) # 下に進む
reward, done, info = env.step(3) # 右に進む
print(f"報酬:\t{reward}")
print(f"終了判定:\t{done}")
print(f"追加情報:\t{info}")
# 描画
env.render()
下に進んだ場合、画像が変わって現在地が動きました。
左に進むと、画像は変わらずに壁にぶつかったことで強制終了された旨が出力されました。
一方で、下に進んで右に進むと穴がありますが終了判定にはなっていません。
また、いずれも元いた場所が変わっていません。
本当であれば穴になって黒になるはずですね。
まずはこの2点について修正を加えましょう!
class GridWorldEnv(gym.Env):
def __init__(self):
# クラス継承のおまじない
super(GridWorldEnv, self).__init__()
# Action Spaceの定義
self.action_space = spaces.Discrete(4)
self.action_desc = """
移動方向
0: 上に進む
1: 下に進む
2: 左に進む
3: 右に進む
"""
# Observation Spaceの定義
self.observation_space = spaces.Box(
low=np.array([-10000, -10000, -10000, -10000]),
high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
)
self.observation_desc = """
報酬の種類
黒(穴): -10000
濃い紫: -3
青: -2
緑青: -1
黄: 0
明るい緑: 1
緑: 2
濃い緑: 3
"""
# 壁と報酬の色分け
self.cmap = colors.ListedColormap(
[
"#000000", "#440154", "#3b528b", "#21918c",
"#fde725", "#aadc32", "#5ec962", "#2fb47c"
]
)
self.bounds = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)
def reset(self) -> None:
# Grid Worldの定義
self.grid = np.array([
[
0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
],
[
0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
],
[
0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
],
[
1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
-1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
],
[
2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
],
[
1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
],
[
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
],
[
0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
],
[
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
],
[
1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
],
[
2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
-3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
],
[
1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
],
[
0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
],
[
0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
],
[
1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
],
[
1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
],
[
0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
],
[
0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
]
], dtype=int)
# 行数と列数を取得
self.rows, self.cols = self.grid.shape
# 現在地(スタート地点)を取得
self.current_pos = [0, 0]
# 報酬を0にする
self.total = 0
def render(self) -> None:
# 描画
fig, ax = plt.subplots(figsize=(15, 10), tight_layout=True)
im = ax.imshow(self.grid, cmap=self.cmap, norm=self.norm)
# 現在地
ax.add_patch(
plt.Rectangle(
(self.current_pos[0] - 0.5, self.current_pos[1] - 0.5), 1, 1,
edgecolor="red", linewidth=2, fill=False
)
)
ax.text(
self.current_pos[0], self.current_pos[1], "P",
va="center", ha="center", fontsize=15, color="red", weight="bold"
)
# スタート地点:左上
if (self.current_pos[0] != 0) or (self.current_pos[1] != 0):
ax.add_patch(
plt.Rectangle(
(0 - 0.5, 0 - 0.5), 1, 1, edgecolor="white",
linewidth=2, fill=False
)
)
ax.text(
0, 0, "S", va="center", ha="center",
fontsize=15, color="white", weight="bold"
)
# ゴール地点:右下
if (self.current_pos[0] != self.cols - 1) or\
(self.current_pos[1] != self.rows - 1):
ax.add_patch(
plt.Rectangle(
(self.cols - 1 - 0.5, self.rows - 1 - 0.5), 1, 1,
edgecolor="orange", linewidth=2, fill=False
)
)
ax.text(
self.cols - 1, self.rows - 1, "G", va="center", ha="center",
fontsize=15, color="orange", weight="bold"
)
# 軸オフ、レイアウト調整
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(
"Grid World (Start: Top-Left, Goal: Bottom-Right)", fontsize=16
)
plt.show()
def step(self, action: int) -> tuple[np.int64, bool, dict]:
# 終了判定をFalseに
done = False
# 追加情報
info = {}
# 現在地をコピーしてactionを元に遷移
tmp = self.current_pos.copy()
if action == 0:
if self.current_pos[1] - 1 >= 0:
self.current_pos[1] -= 1
else:
info["error"] = "wall"
done = True
elif action == 1:
if self.current_pos[1] + 1 <= self.rows - 1:
self.current_pos[1] += 1
else:
info["error"] = "wall"
done = True
elif action == 2:
if self.current_pos[0] - 1 >= 0:
self.current_pos[0] -= 1
else:
info["error"] = "wall"
done = True
else:
if self.current_pos[0] + 1 <= self.cols - 1:
self.current_pos[0] += 1
else:
info["error"] = "wall"
done = True
# 現在の報酬
reward = self.grid[self.current_pos[1]][self.current_pos[0]]
# 報酬が-10000であれば強制終了、そうでなければ合計報酬に加算
if reward == -10000:
if info != {}:
info["error"] = "hole"
done = True
else:
self.total += reward
# 遷移前の地点を穴にする
self.grid[tmp[1]][tmp[0]] = -10000
return reward, done, info
# 環境のインスタンスを用意
env = GridWorldEnv()
# 初期化
env.reset()
# 報酬と終了判定と追加情報を取得
reward, done, info = env.step(1) # 下に進む
reward, done, info = env.step(3) # 右に進む
print(f"報酬:\t{reward}")
print(f"終了判定:\t{done}")
print(f"追加情報:\t{info}")
# 描画
env.render()
はい、ちゃんと直っているのが確認できました!
しかし、まだ終了判定が必要な状況があるはずです。
それは、右下のゴールに辿り着いたときになります。
ゴールに着いたら終了するように処理を加えましょう!
class GridWorldEnv(gym.Env):
def __init__(self):
# クラス継承のおまじない
super(GridWorldEnv, self).__init__()
# Action Spaceの定義
self.action_space = spaces.Discrete(4)
self.action_desc = """
移動方向
0: 上に進む
1: 下に進む
2: 左に進む
3: 右に進む
"""
# Observation Spaceの定義
self.observation_space = spaces.Box(
low=np.array([-10000, -10000, -10000, -10000]),
high=np.array([3, 3, 3, 3]), shape=(4,), dtype=int
)
self.observation_desc = """
報酬の種類
黒(穴): -10000
濃い紫: -3
青: -2
緑青: -1
黄: 0
明るい緑: 1
緑: 2
濃い緑: 3
"""
# 壁と報酬の色分け
self.cmap = colors.ListedColormap(
[
"#000000", "#440154", "#3b528b", "#21918c",
"#fde725", "#aadc32", "#5ec962", "#2fb47c"
]
)
self.bounds = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)
def reset(self) -> None:
# Grid Worldの定義
self.grid = np.array([
[
0, 0, 0, 0, -10000, 1, 1, 0, -1, -2, 0, 1, 0, -10000, -10000,
2, 3, 2, 1, 0, 0, -1, 0, -10000, 1, 0, 0, 1, 2, 3
],
[
0, -10000, -10000, 0, -1, 0, 1, -10000, -2, -3, 0, 1, 1, 0, 0,
1, 2, -10000, -10000, 0, -1, 0, 1, 1, 0, 0, -10000, 2, 2, 1
],
[
0, -1, -1, 0, 0, 1, 2, 2, 1, 0, 0, -1, -1, 0, 0,
1, -10000, -2, -3, -2, -1, 0, 0, -10000, 1, 1, 2, 3, 2, 1
],
[
1, 1, 2, 2, 0, -10000, -10000, -1, -2, -1, 0, 0, 1, 1, 0,
-1, -2, -3, -10000, 0, 0, 1, 1, 2, 2, -10000, 0, -1, 0, 1
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0,
1, 1, 0, -1, -1, 0, 0, -10000, 1, 2, 2, 1, 0, -1, 0
],
[
2, -10000, -10000, 0, -1, -2, 0, 1, 2, 3, 2, 1, 0, -1, -1,
0, 1, 1, -10000, 2, 3, 3, 2, 1, 0, 0, -1, -2, -10000, 0
],
[
1, 1, 0, -1, -2, -2, -1, 0, 1, 2, -10000, -10000, 0, 1, 1,
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -10000, 1, 0
],
[
0, -1, -2, -3, -10000, 0, 1, 2, 3, 2, 1, 0, -1, -2, -10000,
1, 2, 2, 1, 0, 0, -1, -2, -10000, 2, 2, 1, 0, -1, 0
],
[
0, 0, 0, 0, 0, 1, 1, -10000, 1, 1, 0, -1, -2, -2, -1,
0, 1, 1, 2, 3, 2, 1, -10000, 0, 0, -1, -2, -3, -10000, 1
],
[
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, -10000, 2, 3,
2, 1, 0, -1, -10000, 0, 0, 1, 1, 2, 2, -10000, 1, 0, -1
],
[
1, -10000, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -1, 0,
1, 2, 3, 2, 1, 0, -10000, 0, -1, -2, -1, 0, 1, 1, 2
],
[
2, 3, 3, 2, 1, 0, -1, -2, -10000, 1, 2, 1, 0, -1, -2,
-3, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0, -10000, 1, 2, 3
],
[
1, 1, 0, -1, -2, -10000, 1, 2, 3, 3, 2, 1, 0, -1, 0,
1, 2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3
],
[
0, -10000, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 0, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
],
[
0, 0, 0, 1, 2, 3, 3, 2, 1, 0, -1, -1, 0, 1, 2,
3, 3, 2, 1, 0, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0
],
[
1, 2, -10000, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2, 1,
0, -1, -1, 0, 1, 2, 2, 1, 0, 0, -1, -2, -10000, 1, 2
],
[
2, 3, 2, 1, 0, -1, -2, -2, -1, 0, 1, 2, 3, 3, 2,
1, 0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -3, 0
],
[
1, -10000, -10000, 1, 2, 3, 2, 1, 0, -1, 0, 1, 2, 2, 1,
0, -1, -1, 0, 1, 2, 3, 3, 2, 1, 0, -10000, 1, 2, 3
],
[
0, 1, 2, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2, 2, 1,
0, -1, -2, -3, -10000, 1, 2, 2, 1, 0, -1, -2, -10000, 1, 2
],
[
0, -10000, 1, 2, 3, 3, 2, 1, 0, -1, -2, -3, -10000, 1, 2,
2, 1, 0, -1, -2, -2, -1, 0, -10000, 1, 2, 2, 1, 0, -1
]
], dtype=int)
# 行数と列数を取得
self.rows, self.cols = self.grid.shape
# 現在地(スタート地点)を取得
self.current_pos = [0, 0]
# 報酬を0にする
self.total = 0
def render(self) -> None:
# 描画
fig, ax = plt.subplots(figsize=(15, 10), tight_layout=True)
im = ax.imshow(self.grid, cmap=self.cmap, norm=self.norm)
# 現在地
ax.add_patch(
plt.Rectangle(
(self.current_pos[0] - 0.5, self.current_pos[1] - 0.5), 1, 1,
edgecolor="red", linewidth=2, fill=False
)
)
ax.text(
self.current_pos[0], self.current_pos[1], "P",
va="center", ha="center", fontsize=15, color="red", weight="bold"
)
# スタート地点:左上
if (self.current_pos[0] != 0) or (self.current_pos[1] != 0):
ax.add_patch(
plt.Rectangle(
(0 - 0.5, 0 - 0.5), 1, 1, edgecolor="white",
linewidth=2, fill=False
)
)
ax.text(
0, 0, "S", va="center", ha="center",
fontsize=15, color="white", weight="bold"
)
# ゴール地点:右下
if (self.current_pos[0] != self.cols - 1) or\
(self.current_pos[1] != self.rows - 1):
ax.add_patch(
plt.Rectangle(
(self.cols - 1 - 0.5, self.rows - 1 - 0.5), 1, 1,
edgecolor="orange", linewidth=2, fill=False
)
)
ax.text(
self.cols - 1, self.rows - 1, "G", va="center", ha="center",
fontsize=15, color="orange", weight="bold"
)
# 軸オフ、レイアウト調整
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(
"Grid World (Start: Top-Left, Goal: Bottom-Right)", fontsize=16
)
plt.show()
def step(self, action: int) -> tuple[np.int64, bool, dict]:
# 終了判定をFalseに
done = False
# 追加情報
info = {}
# 現在地をコピーしてactionを元に遷移
tmp = self.current_pos.copy()
if action == 0:
if self.current_pos[1] - 1 >= 0:
self.current_pos[1] -= 1
else:
info["error"] = "wall"
done = True
elif action == 1:
if self.current_pos[1] + 1 <= self.rows - 1:
self.current_pos[1] += 1
else:
info["error"] = "wall"
done = True
elif action == 2:
if self.current_pos[0] - 1 >= 0:
self.current_pos[0] -= 1
else:
info["error"] = "wall"
done = True
else:
if self.current_pos[0] + 1 <= self.cols - 1:
self.current_pos[0] += 1
else:
info["error"] = "wall"
done = True
# 現在の報酬
reward = self.grid[self.current_pos[1]][self.current_pos[0]]
# 報酬が-10000であれば強制終了、そうでなければ合計報酬に加算
if reward == -10000:
if info != {}:
info["error"] = "hole"
done = True
else:
self.total += reward
# 遷移前の地点を穴にする
self.grid[tmp[1]][tmp[0]] = -10000
# ゴールに着いたら終了判定をTrueに
if self.current_pos == [self.cols - 1, self.rows - 1]:
info["clear"] = self.total
done = True
return reward, done, info
# 環境のインスタンスを用意
env = GridWorldEnv()
# 初期化
env.reset()
# ゴールまでのルートを用意
routes = [
1, 1, 1, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3, 1, 3, 1, 2, 2, 1, 1, 1, 1, 3, 1,
1, 1, 3, 1, 2, 2, 2, 2, 2, 1, 1, 3, 3, 3, 1, 3, 3, 3, 1, 1, 1, 3, 0, 0, 0,
0, 3, 0, 3, 1, 3, 3, 3, 3, 1, 3, 3, 3, 1, 2, 2, 1, 1, 3, 0, 3, 1, 3, 0, 0,
3, 1, 1, 1, 3, 3, 3, 3, 3
]
# 報酬と終了判定と追加情報を取得
for route in routes:
reward, done, info = env.step(route)
print(f"報酬:\t{reward}")
print(f"終了判定:\t{done}")
print(f"追加情報:\t{info}")
# 描画
env.render()
はい、見事に終了判定が出ました!
これで .step()
メソッドも完了です。
次回予告
次回は、今回作成したGrid World環境でエージェントと相互作用するような仕組みを構築します。
いよいよエージェント側の実装も進めていきますので、ここまでで分からないことが無いようにしておきましょう。
Discussion