🦀

[深層強化学習] RustでDQN (Deep Q Network) をフルスクラッチで実装してみた

2025/03/05に公開

※ 2025/03/01にqiitaに投稿した記事を同じ内容でzennに投稿してます。

はじめに

個人開発しているRust製の強化学習フレームワーク(ReinforceX)にDQNを実装したので、それについて解説する記事です。
gitレポジトリはこちらになります(もしこの記事が良かったらStarを付けてもらえると、やる気がでます!)。
また、crates.ioにもreinforcexという名のクレートとして公開しています。

本記事でやること

  • 実装したDQNでCartPoleを学習してみる
  • APIについての解説
  • DQNの実装についての解説(下記の実装も含む)
    • ε-greedy法
    • Replay Buffer

本記事でやらないこと

  • ニューラルネットワーク本体の実装
    • PyTorchのRustバインディングであるtchクレートを使っています
  • 強化学習やDQNの基礎理論に関する解説(他に良い解説記事がたくさんあるので、そちらを参照してください)
    • 強化学習の基礎知識に不安がある方は、例えばこちらを参照していただけると、本記事もスムーズに読めると思います。

なぜRustなのか

  • 所有権システムによりメモリ使用量を抑えられる
    • メモリ使用量の予測も立てやすくなり、メモリリークを防げる
    • Pythonの場合、ガベージコレクションが走るまでメモリ解放されないことがあり、瞬間的にメモリ使用量が上がりメモリリークを引き起こすリスクが付きまとっている
  • 並列学習においてメモリの安全性が高い
    • 強化学習では並列学習をするケースもあり、RustのArcなどの仕組みでメモリ安全に実行できる
  • コンパイラ型言語なのでシステムに統合しやすい
    • Pythonの場合、モデルの推論をするためにわざわざPythonのランタイムを用意してendpointを作らなければならない(多言語から呼び出す場合)
    • ONNXを使う手もあるが、未対応の関数やモデルが存在する
  • シンプルに実行速度が速い
    • ステップ軸ではなく現実の時間軸での学習速度の速さは強化学習する上で何気に重要
    • 環境もできればRustで書くと良い

CartPoleを学習させてみる

CartPole環境はPythonのライブラリであるgymnasiumをRustから呼んで使っています。なので今回のサンプルを動かすためには、Python環境を構築しgymnasiumpip installする必要があります。Pythonのバージョンは3.11を用いました。また、gymnasiumのバージョンは0.26.3です。また、Rustから呼べるようにPythonのPathを環境変数に追加しておきましょう。

$ pip install gymnasium==0.26.3

続いて、Rust1.85.0の環境を用意してください。
rustc --versionを実行して以下のように表示されていれば問題ないです。

$ rustc --version
rustc 1.85.0 (4d91de4e4 2025-02-17)

そして、サンプル用の新たなprojectを作成しましょう。

$ cargo new rust_rl_sample
$ cd rust_rl_sample

その後、Cargo.tomlのdependenciesを以下のように設定します。

[dependencies]
reinforcex = "0.0.3"
tch = {version="0.18.0"}
gym = { git = "https://github.com/kakky-hacker/gym-rs.git", branch = "master", version="2.2.1" }

main.rsに以下のコードをコピペしたら準備完了です。
cargo runで早速動かしてみましょう!

use gym::client::MakeOptions;
extern crate gym;
use gym::Action;
use reinforcex::agents::{BaseAgent, DQN};
use reinforcex::explorers::EpsilonGreedy;
use reinforcex::models::FCQNetwork;
use tch::{nn, nn::OptimizerConfig, Device, Kind, Tensor};

fn main() {
    println!("train_cartpole_with_dqn");

    let device = Device::cuda_if_available();
    let vs = nn::VarStore::new(device);
    let n_input_channels = 4;
    let action_size = 2;
    let n_hidden_layers = 2;
    let n_hidden_channels = Some(128);

    let model = Box::new(FCQNetwork::new(
        &vs,
        n_input_channels,
        action_size,
        n_hidden_layers,
        n_hidden_channels,
    ));

    let optimizer = nn::Adam::default().build(&vs, 3e-4).unwrap();
    let explorer = EpsilonGreedy::new(0.5, 0.1, 50000);
    let gamma = 0.97;
    let n_steps = 3;
    let batchsize = 16;
    let update_interval = 8;
    let target_update_interval = 100;
    let replay_buffer_capacity = 2000;

    let mut agent = DQN::new(
        model,
        optimizer,
        action_size as usize,
        batchsize,
        replay_buffer_capacity,
        update_interval,
        target_update_interval,
        Box::new(explorer),
        gamma,
        n_steps,
    );

    let gym = gym::client::GymClient::default();
    let env = gym
        .make(
            "CartPole-v1",
            Some(MakeOptions {
                render_mode: Some(gym::client::RenderMode::Human),
                ..Default::default()
            }),
        )
        .expect("Unable to create environment");

    let mut total_reward = 0.0;
    let mut total_steps = 0;
    let log_interval = 100;
    let max_step = 500;
    let max_episode = 10000;
    for episode in 1..max_episode {
        env.reset(None).unwrap();
        let mut reward = 0.0;
        let mut obs = vec![0.0; 4];
        for step in 1..max_step {
            let obs_ = Tensor::from_slice(&obs).to_kind(Kind::Float);
            let action_;
            action_ = agent.act_and_train(&obs_, reward);
            let state = env
                .step(&Action::Discrete(action_.int64_value(&[]) as usize))
                .unwrap();
            obs = state.observation.get_box().unwrap().to_vec();
            if step % 20 == 0 {
                reward = 5.0;
            } else {
                reward = 0.0;
            }
            if state.is_done || step == max_step {
                let obs_ = Tensor::from_slice(&obs).to_kind(Kind::Float);
                if step != max_step {
                    reward = -30.0;
                }
                agent.stop_episode_and_train(&obs_, reward);
                break;
            }
            env.render();
            total_reward += reward;
            total_steps += 1;
        }
        if episode % log_interval == 0 {
            println!(
                "{} episode, average reward:{}, average steps:{}",
                episode,
                total_reward / log_interval as f64,
                total_steps / log_interval,
            );
            total_reward = 0.0;
            total_steps = 0;
        }
    }
    env.close();
}

以下の画像のようにCartPoleの学習が始まると思います。200episodeほど学習すると、安定して棒を立たせられるようになってきます。

Windowsのノートパソコン(corei7-10870H, 16GB, GPU未使用)で学習させましたが、かなりスムーズでした!また、メモリも約100MBほどの使用量でした。

API解説

Agentのインターフェースは以下になっています。
訓練時は毎ステップact_and_trainを呼び、エピソードの最後のステップでstop_episode_and_trainが呼ばれる想定です。
推論時は毎ステップactを呼び、actはQ値が最大の行動を返します。
※メソッド名はPFNが開発していたChainerRLを参考にしました。訓練に使うメソッドは語尾にtrainが付いており、わかりやすいです。

fn act(&self, obs: &Tensor) -> Tensor;
fn act_and_train(&mut self, obs: &Tensor, reward: f64) -> Tensor;
fn stop_episode_and_train(&mut self, obs: &Tensor, reward: f64);

また、例えばDQNをインスタンス化する際、コンストラクタに与える引数は以下です。

  • model
    • Q関数のモデル
    • ニューラルネットワーク(以下、NN)の本体はこのmodelに含まれます
    • 様々なmodelに対応するため、dynを使って動的ディスパッチとしています
  • optimizer
    • 最適化関数
  • action_size
    • 行動空間のサイズ(選択肢の数)
  • batch_size
    • modelの重み更新時にReplay Bufferからサンプリングする経験の個数
  • replay_buffer_capacity
    • ReplayBufferに格納可能な経験の最大個数
    • 溢れたら古い経験から削除されます
  • update_interval
    • modelの重み更新を行うstep間隔
  • target_update_interval
    • ターゲットネットワークと学習対象のmodelの重みを同期するstep間隔
  • explorer
    • 探索アルゴリズム
    • EpsilonGreedyなどのインスタンスが入る
    • modelと同じで動的ディスパッチを使っています。こうすることで、ユーザが独自に定義した探索アルゴリズムも使用可能です
  • gamma
    • 割引率
  • n_steps
    • n-step TD法のnの値
pub fn new(
    model: Box<dyn BaseQFunction>,
    optimizer: nn::Optimizer,
    action_size: usize,
    batch_size: usize,
    replay_buffer_capacity: usize,
    update_interval: usize,
    target_update_interval: usize,
    explorer: Box<dyn BaseExplorer>,
    gamma: f64,
    n_steps: usize,
) -> Self {
    let target_model = model.clone();
    DQN {
        model,
        optimizer,
        replay_buffer: ReplayBuffer::new(replay_buffer_capacity, n_steps),
        explorer,
        action_size,
        batch_size,
        update_interval,
        target_model,
        target_update_interval,
        gamma,
        n_steps,
        t: 0,
    }
}

実装詳細

続いては、DQNの中身を見ていきましょう。DQN全体のコードはこちらにあります。
先ずは一番簡単なactメソッドです。推論時に使われる想定でなのでtch::no_gradで勾配計算を無効化し、let q_values = self.model.forward(&state);で各行動のQ値を計算しています。ここで、q_valuestch::Tensor型です。そして、tch::Tensor::argmaxで要素の値が最大となるindexを求めてreturnしています。

fn act(&self, obs: &Tensor) -> Tensor {
    tch::no_grad(|| {
        let state = batch_states(&vec![obs.shallow_clone()], self.model.is_cuda());
        let q_values = self.model.forward(&state);
        q_values.argmax(1, false).to_device(Device::Cpu)
    })
}

次は、act_and_trainメソッドです。actと比べると大分長いですが上から順に処理を見ていきましょう。
actと同じようにq_valuesを求めた後(こちらは学習時に呼ばれる想定なのでtch::no_gradを使っていません)、greedy_action_funcおよびrandom_action_funcを定義しています。
これらは探索を司るexplorerに渡すための関数で、greedy_action_funcは貪欲な(Q値が最大の)行動を返し、random_action_funcは行動空間からランダムに行動を選択して返します。
その後、self.explorer.select_actionに上記の関数および「現在のステップ数を表すself.t」を入力して、このステップで選択する行動を受け取ります。self.tを入力する理由は、ε-greedyのεの値をステップが進むにつれ小さくしていくためです(アニーリングのような感じです)。
次に、self.replay_bufferに「状態-行動-報酬」を記録しています。
そして、モデルの重み更新(self._updateメソッド)を実行するかどうか、および、ターゲットネットワークとの同期(_sync_target_modelメソッド)を実行するかどうかをself.tより判定しています。self._updateメソッドの中身は重要なので後ほど解説します。
最後にactionを返して本メソッドは終わりです。

fn act_and_train(&mut self, obs: &Tensor, reward: f64) -> Tensor {
    self.t += 1;
    let state = batch_states(&vec![obs.shallow_clone()], self.model.is_cuda());
    let q_values = self.model.forward(&state);

    let greedy_action_func = || q_values.argmax(1, false).int64_value(&[0]) as usize;
    let random_action_func = || rand::random::<usize>() % self.action_size;

    let action_idx = self.explorer.select_action(self.t, &random_action_func, &greedy_action_func);
    let action = Tensor::from_slice(&[action_idx as i64]).detach().to_device(Device::Cpu);

    self.replay_buffer.append(
        state,
        Some(action.shallow_clone()),
        reward,
        false,
        self.gamma,
    );
    if self.t % self.update_interval == 0 {
        self._update();
    }
    if self.t % self.target_update_interval == 0 {
        self._sync_target_model();
    }
    action
}

次は、stop_episode_and_trainメソッドです。このメソッドはエピソード終了時に呼ばれる想定なので行動選択をする必要はなく、単に状態をself.replay_bufferに記録して終了しています。

fn stop_episode_and_train(&mut self, obs: &Tensor, reward: f64) {
    let state = batch_states(&vec![obs.shallow_clone()], self.model.is_cuda());
    self.replay_buffer.append(state, None, reward, true, self.gamma);
}

では、self._updateメソッドの中身を見ていきましょう。まず、Q値の更新式は以下のようになっています。各項の計算を関数で分けています。

Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left( \sum_{k=0}^{n-1} \gamma^k r_{t+k} + \gamma^n \max_{a} Q(s_{t+n}, a) - Q(s_t, a_t) \right)

以下がself._updateメソッドの全体像です。大まかな流れとしては、Replay Bufferから経験をサンプリング→Q値を求める→モデルの予測Q値を求める→損失を計算する→重みを更新する、といった流れになっています。

fn _update(&mut self) {
    if self.replay_buffer.len() < self.batch_size {
        return;
    }
    let experiences = self.replay_buffer.sample(self.batch_size);
    let mut states: Vec<Tensor> = vec![];
    let mut n_step_after_states: Vec<Tensor> = vec![];
    let mut actions: Vec<Tensor> = vec![];
    let mut n_step_discounted_rewards: Vec<f64> = vec![];
    for experience in experiences {
        let state = experience.state.shallow_clone();
        let n_step_after_state = experience.n_step_after_experience.borrow().as_ref().unwrap().state.shallow_clone();
        let action = experience.action.as_ref().unwrap().shallow_clone();
        let n_step_discounted_reward = experience.n_step_discounted_reward.borrow().unwrap_or(experience.reward_for_this_state);
        states.push(state);
        n_step_after_states.push(n_step_after_state);
        actions.push(action);
        n_step_discounted_rewards.push(n_step_discounted_reward);
    }
    let q_values = self._compute_q_values(&n_step_after_states, n_step_discounted_rewards);
    let pred_q_values = self._compute_pred_q_values(&states, actions);
    let loss = self._compute_loss(q_values, pred_q_values);
    self.optimizer.zero_grad();
    loss.backward();
    self.optimizer.step();
}

_updateメソッドで使っている各関数について見ていきます。先ずは、_compute_q_valuesメソッドです。
このメソッドでは教師となるQ値を求めています。引数のn_step_discounted_rewardsは、ある状態からnステップ分の報酬の割引率を考慮した総和です。n_step_discounted_rewardsの各要素を式で表すと以下です。

\sum_{k=0}^{n-1} \gamma^k r_{t+k}

ここにターゲットネットワークが予測する「nステップ後の状態における最大Q値」を割引率を考慮して足したものが推論すべきQ値となります。
Q値の更新式において、_compute_q_valuesメソッドでは以下の値を計算しています。

\sum_{k=0}^{n-1} \gamma^k r_{t+k} + \gamma^n \max_{a} Q(s_{t+n}, a)
fn _compute_q_values(
    &self,
    n_step_after_states: &Vec<Tensor>,
    n_step_discounted_rewards: Vec<f64>,
) -> Tensor {
    assert_eq!(n_step_after_states.len(), n_step_discounted_rewards.len());
    let _states = batch_states(n_step_after_states, self.model.is_cuda());
    let pred_q_values = self.target_model.forward(&_states);
    let max_q_values = pred_q_values.max_dim(1, false).0;
    let gamma_n = self.gamma.powi(self.n_steps as i32);
    let n_step_discounted_rewards_tensor = Tensor::from_slice(&n_step_discounted_rewards);
    let updated_q_values = max_q_values * gamma_n + n_step_discounted_rewards_tensor;
    updated_q_values
}

次は、_compute_pred_q_valuesメソッドです。このメソッドでは現在のself.modelが予測するQ値を求めています。Q値の更新式において以下に該当する値です。

Q(s_t, a_t)
fn _compute_pred_q_values(&self, states: &Vec<Tensor>, actions: Vec<Tensor>) -> Tensor {
    assert_eq!(states.len(), actions.len());
    let _states = batch_states(states, self.model.is_cuda());
    let pred_q_values = self.model.forward(&_states);
    let actions = Tensor::stack(&actions, 0).to_kind(tch::Kind::Int64);
    let pred_q_values_selected = pred_q_values.gather(1, &actions, false).squeeze();
    pred_q_values_selected
}

そして、_compute_lossメソッドです。現状は平均二乗誤差を使っていますが、Huber損失も実装予定です。

fn _compute_loss(&self, q_values: Tensor, pred_q_values: Tensor) -> Tensor {
    let loss = (q_values - pred_q_values).square().mean(tch::Kind::Float);
    loss
}

以上が_updateメソッドの中身の解説です。
ちなみにn_step_discounted_rewards_updateメソッド内ではなく敢えてReplayBuffer内で、経験が記録された際に計算する仕様としています。こうすることにより_updateメソッドで行う計算を若干減らすことができます。
Replay Buffer内でn_step_discounted_rewardsを計算するコードは以下です。cumsum_revは、与えられた配列に対して逆順に重み付きの累積和を求めます。ここでいう重みというのは割引率のことです。

if !self.last_n_experiences.is_empty() {
    let mut rewards: Vec<f64> = self.last_n_experiences.clone().into_iter().skip(1).map(|e| e.reward_for_this_state).collect();
    rewards.push(reward);
    let n_step_discounted_reward = cumsum::cumsum_rev(&rewards, gamma)[0];
    *self.last_n_experiences.front_mut().n_step_discounted_reward.borrow_mut() = Some(n_step_discounted_reward);
    *self.last_n_experiences.front_mut().n_step_after_experience.borrow_mut() = Some(Rc::clone(&experience));
}

最後にEpsilonGreedyのコードを見ていきましょう。短いので全体を張ってしまいます。ε-greedy法を簡潔に説明すると、εの確率でランダムに行動を選択し、1-εの確率で貪欲な行動を選択する手法です。ε-greedy法の具体的な計算式は以下です。Aは行動空間を表し、∣A∣は行動空間のサイズです。

\pi(action | state) = \begin{cases} \frac{1 - \varepsilon}{|\mathcal{A}^*(state)|} + \frac{\varepsilon}{|\mathcal{A}|}, & \text{if } a \in \mathcal{A}^*(state) \\ \frac{\varepsilon}{|\mathcal{A}|}, & \text{otherwise} \end{cases}
use super::base_explorer::BaseExplorer;
use rand::Rng;

pub struct EpsilonGreedy {
    start_epsilon: f64,
    end_epsilon: f64,
    decay_steps: usize,
}

impl EpsilonGreedy {
    pub fn new(start_epsilon: f64, end_epsilon: f64, decay_steps: usize) -> Self {
        assert!((0.0..=1.0).contains(&start_epsilon));
        assert!((0.0..=1.0).contains(&end_epsilon));
        assert!(decay_steps >= 0);
        EpsilonGreedy {
            start_epsilon,
            end_epsilon,
            decay_steps,
        }
    }
}

impl BaseExplorer for EpsilonGreedy {
    fn select_action(
        &self,
        t: usize,
        random_action_func: &dyn Fn() -> usize,
        greedy_action_func: &dyn Fn() -> usize,
    ) -> usize {
        let epsilon;
        if t > self.decay_steps {
            epsilon = self.end_epsilon
        } else {
            epsilon = self.start_epsilon
                + (self.end_epsilon - self.start_epsilon) * (t as f64 / self.decay_steps as f64)
        }

        let action = if rand::thread_rng().gen::<f64>() < epsilon {
            (random_action_func)()
        } else {
            greedy_action_func()
        };

        action
    }
}

おわりに

Rust ✖ 強化学習はまだまだ多くの可能性を秘めています!この記事を読んで、Rust ✖ 強化学習という「掛け算」に興味を持ってくれる方が1人でも増えれば嬉しいです。
git
crates.io

Discussion