🤖

強化学習をPytorchで実装 方策勾配法(Policy Gradient,PG法)編

2024/11/29に公開

趣旨

PG(Policy Gradient,方策勾配)法のアルゴリズム説明・コード実装をします.
実装に関しては,CartPole(倒立振子)問題とPendulum(振子)問題に取り組みます.
私自身,方策勾配法はゼロから作るDeep Learning ④で学習したため,ここにある証明,コードをベースに進めていきます.Pendulum問題に関しては,こちらの書籍で取り扱っていないです.

価値ベースと方策ベース

強化学習のアルゴリズムには,価値ベース方策ベースのアプローチがあります.

価値ベースでは,間接的に方策を学習します.具体的には,エージェントが価値関数(状態価値や行動価値)を学習し,それに基づいて最適な行動を選ぶことによって,結果的に最適な方策が導き出されます.代表的なアルゴリズムには,Q学習,SARSA,DQNがあります.Q学習の場合,更新式を見ればわかる通り,TD誤差R_t + γ\max_{a} Q(S_{t+1}, a) - Q(S_t, A_t))が最小になるようなQ関数を獲得することを目指しており,最終的には最適な方策を導き出します.DQN編をみれば大体理解できると思います.

方策ベースでは,直接的に方策を学習します.つまり,方策 π(a|s) 自体を明示的にパラメータ化しており,状態を入力したら行動が出力されるということです.エージェントは,方策勾配(つまり方策を微分して更新方向が示されたもの)を使用して,報酬を最大化するように方策を調整します.Q学習やDQNではQ値が最大の行動だけを選択しないようにε-greedy法を使用していましたが,方策勾配法では,方策 π(a|s) 自体が確率分布であるため,探索は自然に行われます.多分あんまりピンと来てないと思いますが,僕もそうでした.コードを見た方が理解は早いと思います.

この2つを組み合わせたアプローチも存在しており,現在これが最大勢力となっています.

方策勾配法(PG法)の導出

方策勾配法は方策ベースで最も重要な証明です.これが理解できれば,強化学習アルゴリズムの中でも大人気のA2CPPOの理解もそこまで難しくないでしょう.証明の手順や出てくる文字もゼロから作るDeep Learning ④と全く一緒なので,詳しく知りたければこちらの書籍を読んでください.

そもそもの目的は,最大の収益が見込める確率的な方策 π(a|s) を獲得することです.この方策 π(a|s) はニューラルネットワークでモデル化します.ニューラルネットワークのパラメータをθとすると,ニューラルネットワークによる方策 π_θ(a|s) と書けます.そして,

強化学習における学習の目的は,エージェントがエピソード中に辿る軌跡\tauにおける収益G(\tau)の最大化することです.このためエージェントは方策を調整し,収益の期待値を最大化するように学習します.これを数式で表したものがJ(\theta)という目的関数です.\thetaはエージェントの方策を決定するパラメータであり,学習ではこのパラメータを最適化していくことで、軌跡での収益を最大化していきます.

\begin{aligned} J(\theta) &= \mathbb{E} _{\tau\sim\theta}\ \lbrack G(\tau) \rbrack \end{aligned}

この目的関数を\thetaで微分し,勾配を計算することができれば,\thetaを適切な方向に更新することができます.Pr(\tau|\theta)はパラメータ\thetaの条件下で軌跡\tauが得られる確率です.微分が理解できていれば,ここあたりの導出は大丈夫でしょう.

\begin{aligned} \nabla _\theta J(\theta) &= \nabla _\theta \sum _\tau G(\tau)\ Pr(\tau|\theta)\\\\ &= \sum _\tau G(\tau) \nabla _\theta\ Pr(\tau|\theta)\ +\sum _\tau Pr(\tau|\theta) \nabla _\theta\ G(\tau)\\\\ &= \sum _\tau G(\tau) \nabla _\theta\ Pr(\tau|\theta)\\\\ &= \sum _\tau G(\tau)\ Pr(\tau|\theta) \frac{\nabla _\theta\ Pr(\tau|\theta)}{Pr(\tau|\theta)}\\\\ &= \sum _\tau G(\tau)\ Pr(\tau|\theta) \nabla _\theta\ \log{Pr(\tau|\theta)}\\\\ &= \mathbb{E} _{\tau\sim\theta}\ \lbrack G(\tau)\ \nabla _\theta \log{Pr(\tau|\theta)} \rbrack\\\\ \end{aligned}

Pr(\tau|\theta)は,以下の数式で表せます.そして,p(s_0)は初期状態の確率,\pi _\theta(a_t|s_t)は方策\pi _\thetaが状態s_tにおいて行動a_tを選択する確率,p(s _{t+1}|s_t)は状態s_tから状態s _{t+1}に遷移する確率です.

\begin{aligned} Pr(\tau|\theta) &= p(s_0) \pi _\theta(a_0|s_0)p(s_1|s_0) \cdots \pi _\theta(a_t|s_t)p(s _{t+1}|s_t)\\\\ &= p(s_0) \prod_t^T \pi _\theta(a_t|s_t)p(s _{t+1}|s_t)\\\\ \end{aligned}

Pr(\tau|\theta)の対数を取ります.

\begin{aligned} \log{Pr(\tau|\theta)} &= \log{p(s_0)}+\log{\prod_t^T\pi _\theta(a_t|s_t)}+\log{\prod_t^Tp(s _{t+1}|s_t)}\\\\ &= \log{p(s_0)}+\sum_t^T\log{\pi _\theta(a_t|s_t)}+\sum_t^T\log{p(s _{t+1}|s_t)}\\\\ \end{aligned}

\log{Pr(\tau|\theta)}を先ほどの導出に合わせて\thetaで微分します.

\begin{aligned} \nabla _\theta\log{Pr(\tau|\theta)}&= \nabla _\theta\log{p(s_0)}+\nabla _\theta\sum_t^T\log{\pi _\theta(a_t|s_t)}+\nabla _\theta\sum_t^T\log{p(s _{t+1}|s_t)}\\\\ &= \sum_t^T\nabla _\theta\log{\pi _\theta(a_t|s_t)}\\\\ \end{aligned}

ここで得られたものを先ほど微分した目的関数に代入します.ここで最終的に得られたものをパラメータ\thetaの勾配として使用します.

\begin{aligned} \nabla _\theta J(\theta) &= \mathbb{E} _{\tau\sim\theta}\ \lbrack G(\tau)\ \nabla _\theta \log{Pr(\tau|\theta)} \rbrack\\\\ &= \mathbb{E} _{\tau\sim\theta}\ \lbrack G(\tau)\ \sum_t^T\nabla _\theta\log{\pi _\theta(a_t|s_t)} \rbrack\\\\ &= \mathbb{E} _{\tau\sim\theta}\ \lbrack \sum_t^TG(\tau)\ \nabla _\theta\log{\pi _\theta(a_t|s_t)} \rbrack \end{aligned}

Pytorchで実装

CartPole問題

DQN編の繰り返しになるため,CartPole問題の説明は省略

ハイパーパラメータ

hyperparameters.py
EPISODES = 5000 # 総エピソード数
MAX_STEP = 500 # 1エピソードでの最大ステップ数
LEARNING_RATE = 0.0002 # 学習率
DISCOUNT_RATE = 0.99 # 割引率
STATE_SIZE = 4 # 状態数
ACTION_SIZE = 2 # 行動数
LOG_EPISODES = 500 # ログ出力のステップ頻度
コード説明

DQNなどの価値ベースのアルゴリズムでは,将来的な報酬の総和を算出するため,あるステップごとにパラメータの更新をしていました.一方,単純な方策勾配法では,将来的な報酬の総和ではなく,実際に獲得した報酬から収益を算出するため,エピソード単位でパラメータの更新をした方が,学習がうまくいきます.

ニューラルネットワーク

net.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, state_size: int, action_size: int)
        super(Net, self).__init__()
        self.layer1 = nn.Linear(state_size, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, action_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.softmax(self.layer3(x), dim=1)
        
        return x
コード説明

ニューラルネットワークの構造自体は,DQNと一緒です.ただ,それぞれの行動に対する出力の和を1にするため,最後の出力層でソフトマックス関数を使用しています.

エージェント

agent.py
from typing import Tuple, Union
import torch.optim as optim
from torch.distributions import Categorical

class Agent:
    def __init__(self, state_size: int, action_size: int):
        self.gamma = DISCOUNT_RATE
        self.lr = LEARNING_RATE

        self.state_size = state_size
        self.action_size = action_size

        self.memory = []
        self.net = Net(self.state_size, self.action_size)
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)

    def get_action(self, state: np.ndarray) -> Tuple[int, torch.Tensor]:
        state = torch.tensor(state[np.newaxis, :])
        probs = self.net(state)
        probs = probs[0]
        d = Categorical(probs)
        action = d.sample()
        log_prob = d.log_prob(action)

        return action.item(), log_prob

    def add_experience(self, reward: Union[int, float], log_prob: torch.Tensor) -> None:
        data = (reward, log_prob)
        self.memory.append(data)

    def update(self) -> None:
        gain, loss = 0, 0
        for reward, log_prob in reversed(self.memory):
            gain = reward + self.gamma * gain

        for reward, log_prob in self.memory:
            loss += -log_prob * gain

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.memory = []
    
    def save_model(self) -> None:
        torch.save(self.net.state_dict(), 'model.pth')
コード説明

Agentクラスには,行動を取得するget_actionメソッド,ネットワークを更新するupdateメソッド,経験を追加するadd_experienceメソッド,モデルを保存するsave_modelメソッドがあります.

get_actionメソッドでは,エージェントの状態から行動を決定します.probsは,ニューラルネットワークから出力された[0.6, 0.4]のようなそれぞれの行動を選択する確率と考えてよいです.torch.distributions.Categoricalを使用することで確率分布を作成していきます.Categorical().sample()で確率分布に従って行動を選択します.

updateメソッドでは,実際に獲得した報酬からエピソード全体での収益を算出します.reversedを使用することで、割引率をうまく計算に組み込むことができます.収益と対数確率からlossを算出します.

実行関数

main.py
import gymnasium as gym
from logging import getLogger, basicConfig, INFO

basicConfig(level=INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = getLogger(__name__)

def main():
    # train
    env = gym.make('CartPole-v1')
    agent = Agent(STATE_SIZE, ACTION_SIZE)
    for e in range(EPISODES):
        done = False
        state, _ = env.reset()
        step = 0
        while step < MAX_STEPS:
            action, prob = agent.get_action(state)
            next_state, reward, done, _, _ = env.step(action)
            agent.add_experience(reward, prob)

            state = next_state
            step += 1

            if done:
            break
        if e % LOG_EPISODES == 0:
            logger.info("episodes: %d", e)
        agent.update()
    agent.save_model()
    env.close()
    # test
    env = gym.make('CartPole-v1', render_mode="human")
    done = False
    state, _  = env.reset()
    step = 0
    while step < MAX_STEPS:
        env.render()
        action, _ = agent.get_action(state)
        next_state, reward, done, _, _ = env.step(action)
        state = next_state

        if done:
            break
            
if __name__ == '__main__':
    main()      
コード説明

全体の流れはDQNとほとんど同じです.前述したとおり,ステップとエピソードのどちらを軸に学習を進めていくかが異なります.

Pendulum問題

Pendulum問題は、倒立振子(垂直に立った棒)を振子のように制御し、できるだけ上向きに保つことを目指します。CartPole問題とは異なり、連続的な行動空間を持つ制御問題です。

行動空間

No. 行動 最小値 最大値
0 トルク -2.0 2.0

観測空間

No. 観測 最小値 最大値
0 x = cos(theta) -1.0 1.0
1 y = sin(theta) -1.0 1.0
2 角速度 -8.0 8.0

報酬

R_t = -(theta^2(t) + 0.1\dot{theta^2(t)} + 0.001torque(t)^2)

レコーディング-2024-09-30-170626.gif

ハイパーパラメータ

hyperparameters.py
EPISODES = 5000 # 総エピソード数
MAX_STEP = 500 # 1エピソードでの最大ステップ数
LEARNING_RATE = 0.0002 # 学習率
DISCOUNT_RATE = 0.99 # 割引率
STATE_SIZE = 3 # 状態数
ACTION_SIZE = 1 # 行動数
LOG_EPISODES = 500 # ログ出力のステップ頻度
コード説明

行動出力が離散値か連続値かで最適なハイパーパラメータが結構変わってきますが,今回はCartPole問題と同じでハイパーパラメータで進めます.

ニューラルネットワーク

net.py
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, state_size: int, action_size: int):
        super(Net, self).__init__()
        self.layer1 = nn.Linear(state_size, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, action_size)
        self.param1 = nn.Parameter(torch.ones(action_size) * 0, requires_grad=True)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        mean_action = self.layer3(x)
        log_std = self.param1

        return mean_action, log_std

コード説明

行動が連続値の場合,行動は正規分布からサンプリングされます.この際,ニューラルネットワークの出力は正規分布の平均値と標準偏差になります.ただし,標準偏差は負の値を取ることができません.そのため,ニューラルネットワークでは標準偏差そのものでなく,その対数を出力します.対数は負の値も取ることができ,ニューラルネットワークの出力は正負の制約を受けないため,都合が良いです.

エージェント

agent.py
from typing import Union
import torch.optim as optim
from torch.distributions import Normal

class Agent:
    def __init__(self, state_size: int, action_size: int):
        self.gamma = DISCOUNT_RATE
        self.lr = LEARNING_RATE

        self.state_size = state_size
        self.action_size = action_size

        self.memory = []
        self.net = Net(self.state_size, self.action_size)
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        
    def get_action(self, state: np.ndarray) -> Tuple[np.ndarray, torch.Tensor]:
        state = torch.tensor(state[np.newaxis, :])
        mean_action, log_std = self.net(state)
        std = torch.ones_like(mean_action) * log_std.exp()
        d = Normal(mean_action, std)
        action = d.rsample()
        log_prob = d.log_prob(action)

        return np.array([action]), log_prob

    def add_experience(self, reward: Union[int, float], log_prob: torch.Tensor) -> None:
        data = (reward, log_prob)
        self.memory.append(data)

    def update(self) -> None:
        gain, loss = 0, 0
        for reward, log_prob in reversed(self.memory):
            gain = reward + self.gamma * gain

        for reward, log_prob in self.memory:
            loss += - log_prob * gain

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.memory = [] 

    def save_model(self) -> None:
        torch.save(self.net.state_dict(), 'model.pth')
コード説明

get_actionメソッドが変わります.まず,出力された標準偏差の対数を標準偏差に変換します.次に,正規分布を作成し,行動をサンプリングします.出力が離散値の際は,d.sample()としていましたが,ここではd.rsample()としています.pytorch内のsample()とrsample()のコードは下の通りです.

sample.py
def sample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    with torch.no_grad():
        return torch.normal(self.loc.expand(shape), self.scale.expand(shape))

sample()は,torch.normal()でただサンプリングしているだけです.サンプリング処理は微分不可能であり(だからwith torch.no_grad()としている),勾配がここで失われるため誤差逆伝播ができません.

rsample.py
def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
    shape = self._extended_shape(sample_shape)
    eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
    return self.loc + eps * self.scale

rsample()は,再パラメータ化トリック(Reparametrization Trick)を使用することで,誤差逆伝播を可能としています.再パラメータ化トリックでは,別の確率変数がサンプリング要素を担っています.self.loc()が平均値,self.scale()が標準偏差であり,epsが平均0、分散共分散行列が単位行列Iの正規分布に従う変数です.再パラメータ化トリックは,VAE(Variational Autoencoder)においてエンコーダで潜在変数zをサンプリングする際にも使用されます.

実行関数

main.py
import gymnasium as gym
from logging import getLogger, basicConfig, INFO

basicConfig(level=INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = getLogger(__name__)

def main():
    # train
    env = gym.make('Pendulum-v1', g=9.81)
    agent = Agent(STATE_SIZE, ACTION_SIZE)
    for e in range(EPISODES):
        done = False
        state, _ = env.reset()
        step = 0
        while step < MAX_STEPS:
            action, prob = agent.get_action(state)
            next_state, reward, done, _, _ = env.step(action)
            agent.add_experience(reward, prob)

            state = next_state
            step += 1

            if done:
            break
        if e % LOG_EPISODES == 0:
            logger.info("episodes: %d", e)
        agent.update()
    agent.save_model()
    env.close()
    # test
    env = gym.make('Pendulum-v1', g=9.81, render_mode="human")
    done = False
    state, _  = env.reset()
    step = 0
    while step < MAX_STEPS:
        env.render()
        action, _ = agent.get_action(state)
        next_state, reward, done, _, _ = env.step(action)
        state = next_state

        if done:
            break
            
if __name__ == '__main__':
    main()      
コード説明

ここは全く同じです.

まとめ

大事な箇所は,方策勾配法の導出と行動が連続値である場合のサンプリング方法です.
性能はそこまで良くないです.方策勾配法は,これの発展形であるREINFORCEやActor-Criticでより真価を発揮します.そして,Actor-Criticの発展形であるPPO(Proximal Policy Optimization)かSAC(Soft Actor-Critic)が最近ではよく使用されています.

参考文献

  1. 斎藤康毅, ゼロから作るDeep Learning④ー強化学習編, 株式会社オライリー・ジャパン

Discussion