強化学習をPytorchで実装 方策勾配法(Policy Gradient,PG法)編
趣旨
PG(Policy Gradient,方策勾配)法のアルゴリズム説明・コード実装をします.
実装に関しては,CartPole(倒立振子)問題とPendulum(振子)問題に取り組みます.
私自身,方策勾配法はゼロから作るDeep Learning ④で学習したため,ここにある証明,コードをベースに進めていきます.Pendulum問題に関しては,こちらの書籍で取り扱っていないです.
価値ベースと方策ベース
強化学習のアルゴリズムには,価値ベースと方策ベースのアプローチがあります.
価値ベースでは,間接的に方策を学習します.具体的には,エージェントが価値関数(状態価値や行動価値)を学習し,それに基づいて最適な行動を選ぶことによって,結果的に最適な方策が導き出されます.代表的なアルゴリズムには,Q学習,SARSA,DQNがあります.Q学習の場合,更新式を見ればわかる通り,TD誤差(
方策ベースでは,直接的に方策を学習します.つまり,方策
この2つを組み合わせたアプローチも存在しており,現在これが最大勢力となっています.
方策勾配法(PG法)の導出
方策勾配法は方策ベースで最も重要な証明です.これが理解できれば,強化学習アルゴリズムの中でも大人気のA2CやPPOの理解もそこまで難しくないでしょう.証明の手順や出てくる文字もゼロから作るDeep Learning ④と全く一緒なので,詳しく知りたければこちらの書籍を読んでください.
そもそもの目的は,最大の収益が見込める確率的な方策
強化学習における学習の目的は,エージェントがエピソード中に辿る軌跡
この目的関数を
ここで得られたものを先ほど微分した目的関数に代入します.ここで最終的に得られたものをパラメータ
Pytorchで実装
CartPole問題
DQN編の繰り返しになるため,CartPole問題の説明は省略
ハイパーパラメータ
EPISODES = 5000 # 総エピソード数
MAX_STEP = 500 # 1エピソードでの最大ステップ数
LEARNING_RATE = 0.0002 # 学習率
DISCOUNT_RATE = 0.99 # 割引率
STATE_SIZE = 4 # 状態数
ACTION_SIZE = 2 # 行動数
LOG_EPISODES = 500 # ログ出力のステップ頻度
コード説明
DQNなどの価値ベースのアルゴリズムでは,将来的な報酬の総和を算出するため,あるステップごとにパラメータの更新をしていました.一方,単純な方策勾配法では,将来的な報酬の総和ではなく,実際に獲得した報酬から収益を算出するため,エピソード単位でパラメータの更新をした方が,学習がうまくいきます.
ニューラルネットワーク
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self, state_size: int, action_size: int)
super(Net, self).__init__()
self.layer1 = nn.Linear(state_size, 64)
self.layer2 = nn.Linear(64, 64)
self.layer3 = nn.Linear(64, action_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.relu(self.layer1(x))
x = F.relu(self.layer2(x))
x = F.softmax(self.layer3(x), dim=1)
return x
コード説明
ニューラルネットワークの構造自体は,DQNと一緒です.ただ,それぞれの行動に対する出力の和を1にするため,最後の出力層でソフトマックス関数を使用しています.
エージェント
from typing import Tuple, Union
import torch.optim as optim
from torch.distributions import Categorical
class Agent:
def __init__(self, state_size: int, action_size: int):
self.gamma = DISCOUNT_RATE
self.lr = LEARNING_RATE
self.state_size = state_size
self.action_size = action_size
self.memory = []
self.net = Net(self.state_size, self.action_size)
self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
def get_action(self, state: np.ndarray) -> Tuple[int, torch.Tensor]:
state = torch.tensor(state[np.newaxis, :])
probs = self.net(state)
probs = probs[0]
d = Categorical(probs)
action = d.sample()
log_prob = d.log_prob(action)
return action.item(), log_prob
def add_experience(self, reward: Union[int, float], log_prob: torch.Tensor) -> None:
data = (reward, log_prob)
self.memory.append(data)
def update(self) -> None:
gain, loss = 0, 0
for reward, log_prob in reversed(self.memory):
gain = reward + self.gamma * gain
for reward, log_prob in self.memory:
loss += -log_prob * gain
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.memory = []
def save_model(self) -> None:
torch.save(self.net.state_dict(), 'model.pth')
コード説明
Agentクラスには,行動を取得するget_actionメソッド,ネットワークを更新するupdateメソッド,経験を追加するadd_experienceメソッド,モデルを保存するsave_modelメソッドがあります.
get_actionメソッドでは,エージェントの状態から行動を決定します.probsは,ニューラルネットワークから出力された[0.6, 0.4]のようなそれぞれの行動を選択する確率と考えてよいです.torch.distributions.Categoricalを使用することで確率分布を作成していきます.Categorical().sample()で確率分布に従って行動を選択します.
updateメソッドでは,実際に獲得した報酬からエピソード全体での収益を算出します.reversedを使用することで、割引率をうまく計算に組み込むことができます.収益と対数確率からlossを算出します.
実行関数
import gymnasium as gym
from logging import getLogger, basicConfig, INFO
basicConfig(level=INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = getLogger(__name__)
def main():
# train
env = gym.make('CartPole-v1')
agent = Agent(STATE_SIZE, ACTION_SIZE)
for e in range(EPISODES):
done = False
state, _ = env.reset()
step = 0
while step < MAX_STEPS:
action, prob = agent.get_action(state)
next_state, reward, done, _, _ = env.step(action)
agent.add_experience(reward, prob)
state = next_state
step += 1
if done:
break
if e % LOG_EPISODES == 0:
logger.info("episodes: %d", e)
agent.update()
agent.save_model()
env.close()
# test
env = gym.make('CartPole-v1', render_mode="human")
done = False
state, _ = env.reset()
step = 0
while step < MAX_STEPS:
env.render()
action, _ = agent.get_action(state)
next_state, reward, done, _, _ = env.step(action)
state = next_state
if done:
break
if __name__ == '__main__':
main()
コード説明
全体の流れはDQNとほとんど同じです.前述したとおり,ステップとエピソードのどちらを軸に学習を進めていくかが異なります.
Pendulum問題
Pendulum問題は、倒立振子(垂直に立った棒)を振子のように制御し、できるだけ上向きに保つことを目指します。CartPole問題とは異なり、連続的な行動空間を持つ制御問題です。
行動空間
No. | 行動 | 最小値 | 最大値 |
---|---|---|---|
0 | トルク | -2.0 | 2.0 |
観測空間
No. | 観測 | 最小値 | 最大値 |
---|---|---|---|
0 | x = cos(theta) | -1.0 | 1.0 |
1 | y = sin(theta) | -1.0 | 1.0 |
2 | 角速度 | -8.0 | 8.0 |
報酬
ハイパーパラメータ
EPISODES = 5000 # 総エピソード数
MAX_STEP = 500 # 1エピソードでの最大ステップ数
LEARNING_RATE = 0.0002 # 学習率
DISCOUNT_RATE = 0.99 # 割引率
STATE_SIZE = 3 # 状態数
ACTION_SIZE = 1 # 行動数
LOG_EPISODES = 500 # ログ出力のステップ頻度
コード説明
行動出力が離散値か連続値かで最適なハイパーパラメータが結構変わってきますが,今回はCartPole問題と同じでハイパーパラメータで進めます.
ニューラルネットワーク
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self, state_size: int, action_size: int):
super(Net, self).__init__()
self.layer1 = nn.Linear(state_size, 64)
self.layer2 = nn.Linear(64, 64)
self.layer3 = nn.Linear(64, action_size)
self.param1 = nn.Parameter(torch.ones(action_size) * 0, requires_grad=True)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x = F.relu(self.layer1(x))
x = F.relu(self.layer2(x))
mean_action = self.layer3(x)
log_std = self.param1
return mean_action, log_std
コード説明
行動が連続値の場合,行動は正規分布からサンプリングされます.この際,ニューラルネットワークの出力は正規分布の平均値と標準偏差になります.ただし,標準偏差は負の値を取ることができません.そのため,ニューラルネットワークでは標準偏差そのものでなく,その対数を出力します.対数は負の値も取ることができ,ニューラルネットワークの出力は正負の制約を受けないため,都合が良いです.
エージェント
from typing import Union
import torch.optim as optim
from torch.distributions import Normal
class Agent:
def __init__(self, state_size: int, action_size: int):
self.gamma = DISCOUNT_RATE
self.lr = LEARNING_RATE
self.state_size = state_size
self.action_size = action_size
self.memory = []
self.net = Net(self.state_size, self.action_size)
self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
def get_action(self, state: np.ndarray) -> Tuple[np.ndarray, torch.Tensor]:
state = torch.tensor(state[np.newaxis, :])
mean_action, log_std = self.net(state)
std = torch.ones_like(mean_action) * log_std.exp()
d = Normal(mean_action, std)
action = d.rsample()
log_prob = d.log_prob(action)
return np.array([action]), log_prob
def add_experience(self, reward: Union[int, float], log_prob: torch.Tensor) -> None:
data = (reward, log_prob)
self.memory.append(data)
def update(self) -> None:
gain, loss = 0, 0
for reward, log_prob in reversed(self.memory):
gain = reward + self.gamma * gain
for reward, log_prob in self.memory:
loss += - log_prob * gain
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.memory = []
def save_model(self) -> None:
torch.save(self.net.state_dict(), 'model.pth')
コード説明
get_actionメソッドが変わります.まず,出力された標準偏差の対数を標準偏差に変換します.次に,正規分布を作成し,行動をサンプリングします.出力が離散値の際は,d.sample()としていましたが,ここではd.rsample()としています.pytorch内のsample()とrsample()のコードは下の通りです.
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
with torch.no_grad():
return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
sample()は,torch.normal()でただサンプリングしているだけです.サンプリング処理は微分不可能であり(だからwith torch.no_grad()としている),勾配がここで失われるため誤差逆伝播ができません.
def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
shape = self._extended_shape(sample_shape)
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
return self.loc + eps * self.scale
rsample()は,再パラメータ化トリック(Reparametrization Trick)を使用することで,誤差逆伝播を可能としています.再パラメータ化トリックでは,別の確率変数がサンプリング要素を担っています.self.loc()が平均値,self.scale()が標準偏差であり,epsが平均0、分散共分散行列が単位行列
実行関数
import gymnasium as gym
from logging import getLogger, basicConfig, INFO
basicConfig(level=INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = getLogger(__name__)
def main():
# train
env = gym.make('Pendulum-v1', g=9.81)
agent = Agent(STATE_SIZE, ACTION_SIZE)
for e in range(EPISODES):
done = False
state, _ = env.reset()
step = 0
while step < MAX_STEPS:
action, prob = agent.get_action(state)
next_state, reward, done, _, _ = env.step(action)
agent.add_experience(reward, prob)
state = next_state
step += 1
if done:
break
if e % LOG_EPISODES == 0:
logger.info("episodes: %d", e)
agent.update()
agent.save_model()
env.close()
# test
env = gym.make('Pendulum-v1', g=9.81, render_mode="human")
done = False
state, _ = env.reset()
step = 0
while step < MAX_STEPS:
env.render()
action, _ = agent.get_action(state)
next_state, reward, done, _, _ = env.step(action)
state = next_state
if done:
break
if __name__ == '__main__':
main()
コード説明
ここは全く同じです.
まとめ
大事な箇所は,方策勾配法の導出と行動が連続値である場合のサンプリング方法です.
性能はそこまで良くないです.方策勾配法は,これの発展形であるREINFORCEやActor-Criticでより真価を発揮します.そして,Actor-Criticの発展形であるPPO(Proximal Policy Optimization)かSAC(Soft Actor-Critic)が最近ではよく使用されています.
参考文献
- 斎藤康毅, ゼロから作るDeep Learning④ー強化学習編, 株式会社オライリー・ジャパン
Discussion