😽

thunder本をRustで実装する(3章)

2025/01/14に公開

概要

本記事では、ゲームAIや探索アルゴリズムの入門書として、また AHC の導入としてもよく知られる、「ゲームで学ぶ探索アルゴリズム実践入門~木探索とメタヒューリスティクス」(通称「thunder本」)の3章に関して、Rustで実装したコードと簡単なメモを備忘録的に記します。Rust コードは書籍の C++ のコードと全く同じ処理ではないですが、重要なロジック部分に関してはほとんど同等の内容になっているかと思います。

本記事の筆者は、Rust及び本記事で扱うアルゴリズムいずれも入門レベルであるため、誤りなどがあればご指摘いただけますと幸いです。

thunder本3章に登場するアルゴリズム概要

3章では、「文脈のある[1]」一人ゲームを題材として、「貪欲法」「ビームサーチ」「chokudaiサーチ」がといった探索アルゴリズムが紹介されています。

詳細な解説は書籍に委ねます[2]が、この種のゲームでは「ビームサーチ」や「chokudaiサーチ」が有効であることが多く、また複雑な探索手法を実装する前に、単純に直近一手の最適手を選択することを繰り返す「貪欲法」でベースラインをとっておくのも有用です。

「chokudaiサーチ」は「ビームサーチ」の一種といえますが、探索の「多様性」を自動的に確保しやすいという特徴があります。一方で、「ビームサーチ」はうまくチューニングすると「chokudaiサーチ」に比べてパフォーマンスが出やすいといわれています。書籍では、最初は「chokudaiサーチ」を使い、探索における多様性確保に慣れてきたら「ビームサーチ」を活用することがお勧めされています。

実装

Rust コード

まずはゲームの状態を保持する構造体と、ランダムウォークや貪欲法の実装部分です。

maze_state.rs
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

#[derive(Clone, Eq, PartialEq)]
pub struct Coord {
    pub x: usize,
    pub y: usize,
}

const H: usize = 30;
const W: usize = 30;
const END_TURN: i32 = 100;

#[derive(Clone, Eq, PartialEq)]
pub struct MazeState {
    points: Vec<Vec<i32>>,
    turn: i32,
    pub character: Coord,
    pub game_score: i32,
    pub evaluated_score: i32,
    pub first_action: isize,
}

impl Ord for MazeState {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.evaluated_score.cmp(&other.evaluated_score)
    }
}

impl PartialOrd for MazeState {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl MazeState {
    const DX: [isize; 4] = [1, -1, 0, 0];
    const DY: [isize; 4] = [0, 0, 1, -1];

    pub fn new(seed: u64) -> Self {
        let mut rng = StdRng::seed_from_u64(seed);
        let x = rng.gen_range(0..H);
        let y = rng.gen_range(0..W);

        let points = (0..H)
            .map(|row_idx| {
                (0..W)
                    .map(|col_idx| {
                        if row_idx == x && col_idx == y {
                            0
                        } else {
                            rng.gen_range(0..10)
                        }
                    })
                    .collect::<Vec<i32>>()
            })
            .collect::<Vec<Vec<i32>>>();

        Self {
            points,
            turn: 0,
            character: Coord { x, y },
            game_score: 0,
            evaluated_score: 0,
            first_action: -1,
        }
    }

    pub fn is_done(&self) -> bool {
        self.turn == END_TURN
    }

    fn is_in_board(&self, x: isize, y: isize) -> bool {
        0 <= x && x < H as isize && 0 <= y && y < W as isize
    }

    pub(crate) fn legal_actions(&self) -> Vec<usize> {
        let mut actions = vec![];
        for action in 0..4 {
            let nx = self.character.x as isize + Self::DX[action];
            let ny = self.character.y as isize + Self::DY[action];
            if self.is_in_board(nx, ny) {
                actions.push(action);
            }
        }
        actions
    }

    pub fn advance(&mut self, action: usize) {
        self.character.x = (self.character.x as isize + Self::DX[action]) as usize;
        self.character.y = (self.character.y as isize + Self::DY[action]) as usize;
        let point = &mut self.points[self.character.x][self.character.y];
        self.game_score += *point;
        *point = 0;
        self.turn += 1;
    }

    pub fn to_string(&self) -> String {
        let mut result = String::new();

        result.push_str(&format!(
            "Turn: {}, Score: {}\n",
            self.turn, self.game_score
        ));

        for row in 0..H {
            for col in 0..W {
                if row == self.character.x && col == self.character.y {
                    result.push('@');
                } else if self.points[row][col] > 0 {
                    let val = self.points[row][col];
                    result.push_str(&val.to_string());
                } else {
                    result.push('.');
                }
                result.push(' ');
            }
            result.push('\n');
        }

        result
    }

    pub fn evaluate_score(&mut self) {
        self.evaluated_score = self.game_score
    }
}

#[allow(dead_code)]
pub fn random_action(state: &MazeState) -> usize {
    let legal_actions = state.legal_actions();
    let mut rng = rand::thread_rng();
    legal_actions[rng.gen_range(0..legal_actions.len())]
}

pub fn greedy_action(state: &MazeState) -> usize {
    let legal_actions = state.legal_actions();
    let mut max_score = i32::MIN;
    let mut best_action = legal_actions[0];
    for action in legal_actions {
        let mut current_state = state.clone();
        current_state.advance(action);
        current_state.evaluate_score();
        if current_state.evaluated_score > max_score {
            max_score = current_state.evaluated_score;
            best_action = action;
        }
    }

    best_action
}

pub fn play_game<F>(seed: u64, policy: F, verbose: bool) -> i32
where
    F: Fn(&MazeState) -> usize,
{
    let mut state = MazeState::new(seed);
    if verbose {
        println!("{}", state.to_string());
    }
    while !state.is_done() {
        let action = policy(&state);
        state.advance(action);
        if verbose {
            println!("{}", state.to_string());
        }
    }
    state.game_score
}

続いて、探索時間を制限するための TimeKeeper 構造体です。

time_keeper.rs
use std::time::Instant;

pub struct TimeKeeper {
    start_time: Instant,
    time_threshold_ms: i64,
}

impl TimeKeeper {
    pub fn new(time_threshold_ms: i64) -> Self {
        Self {
            start_time: Instant::now(),
            time_threshold_ms,
        }
    }

    pub fn is_time_over(&self) -> bool {
        let elapsed_ms = self.start_time.elapsed().as_millis() as i64;
        elapsed_ms >= self.time_threshold_ms
    }
}

次はビームサーチの実装です。

beam_search.rs
use crate::maze_state::MazeState;
use crate::time_keeper::TimeKeeper;
use std::collections::BinaryHeap;

const BEAM_WIDTH: usize = 4;
const BEAM_DEPTH: usize = 10;
const TIME_THRESHOLD_MS: i64 = 10;

pub fn beam_search_action(state: &MazeState) -> usize {
    let mut current_beam = BinaryHeap::new();
    let mut best_state: MazeState = state.clone();
    let time_keeper = TimeKeeper::new(TIME_THRESHOLD_MS);

    current_beam.push(state.clone());
    for depth in 0..BEAM_DEPTH {
        let mut next_beam = BinaryHeap::new();
        for _ in 0..BEAM_WIDTH {
            if time_keeper.is_time_over() {
                return best_state.first_action as usize;
            }
            let current_state = if let Some(s) = current_beam.pop() {
                s
            } else {
                break;
            };
            let legal_actions = current_state.legal_actions();
            for action in legal_actions {
                let mut next_state = current_state.clone();
                next_state.advance(action);
                next_state.evaluate_score();
                if depth == 0 {
                    next_state.first_action = action as isize;
                }
                next_beam.push(next_state);
            }
        }

        current_beam = next_beam;
        if current_beam.is_empty() {
            break;
        }

        if let Some(top_state) = current_beam.peek() {
            best_state = top_state.clone();
        }
        if best_state.is_done() {
            break;
        }
    }

    best_state.first_action as usize
}

次に、chokudaiサーチの実装です。

chokudai_search.rs
use crate::maze_state::MazeState;
use crate::time_keeper::TimeKeeper;
use std::collections::BinaryHeap;

const BEAM_WIDTH: usize = 1;
const BEAM_DEPTH: usize = 10;
const BEAM_NUMBER: usize = 4;
const TIME_THRESHOLD_MS: i64 = 10;

pub fn chokudai_search_action(state: &MazeState) -> usize {
    let time_keeper = TimeKeeper::new(TIME_THRESHOLD_MS);
    let mut beam: Vec<BinaryHeap<MazeState>> = vec![BinaryHeap::new(); BEAM_DEPTH + 1];
    beam[0].push(state.clone());

    for _ in 0..BEAM_NUMBER {
        if time_keeper.is_time_over() {
            break;
        }
        for depth in 0..BEAM_DEPTH {
            let (left, right) = beam.split_at_mut(depth + 1);
            let current_beam = &mut left[depth];
            let next_beam = &mut right[0];

            for _ in 0..BEAM_WIDTH {
                if current_beam.is_empty() {
                    break;
                }
                let current_state = current_beam.pop().unwrap();

                if current_state.is_done() {
                    break;
                }

                let legal_actions = current_state.legal_actions();
                for &action in &legal_actions {
                    let mut next_state = current_state.clone();
                    next_state.advance(action);
                    next_state.evaluate_score();
                    if depth == 0 {
                        next_state.first_action = action as isize;
                    }
                    next_beam.push(next_state);
                }
            }
        }
    }

    for depth in (0..=BEAM_DEPTH).rev() {
        if let Some(top_state) = beam[depth].peek() {
            return top_state.first_action as usize;
        }
    }
    state.legal_actions()[0]
}

最後に、各手法を複数回実行してスコアの平均を比較する main 関数です。

main.rs
use crate::maze_state::{greedy_action, play_game, random_action};
use crate::beam_search::beam_search_action;
use crate::chokudai_search::chokudai_search_action;

mod maze_state;
mod beam_search;
mod time_keeper;
mod chokudai_search;

fn main() {
    let trials = 100;
    let mut sum_random = 0;
    let mut sum_greedy = 0;
    let mut sum_beam = 0;
    let mut sum_chokudai = 0;

    let verbose = false;

    for i in 0..trials {
        let score = play_game(i as u64, random_action, verbose);
        sum_random += score;
    }
    let avg_random = sum_random as f64 / trials as f64;

    for i in 0..trials {
        let score = play_game(i as u64, greedy_action, verbose);
        sum_greedy += score;
    }
    let avg_greedy = sum_greedy as f64 / trials as f64;

    for i in 0..trials {
        let score = play_game(i as u64, beam_search_action, verbose);
        sum_beam += score;
    }
    let avg_beam = sum_beam as f64 / trials as f64;

    for i in 0..trials {
        let score = play_game(i as u64, chokudai_search_action, verbose);
        sum_chokudai += score;
    }
    let avg_chokudai = sum_chokudai as f64 / trials as f64;

    println!(
        "Random Action Average Score ({} trials) = {}",
        trials, avg_random
    );
    println!(
        "Greedy Action Average Score ({} trials) = {}",
        trials, avg_greedy
    );
    println!(
        "Beam Search Action Average Score ({} trials) = {}",
        trials, avg_beam
    );
    println!(
        "Chokudai Search Action Average Score ({} trials) = {}",
        trials, avg_chokudai
    );
}

実装のメモ

  • C++ では、メンバー変数に _ を付けるコーディングスタイルがあるようですが、Rust にはそのようなコーディングスタイルは一般的にはなさそうなので削除しています。
    • _ を付けるとしたら、変数の先頭に付けて、その変数が未使用であることをコードの読み手やコンパイラに明示するケースが主なものになるようです。
  • play_game 関数で、*_action をジェネリクス経由で呼び出したかったことと、半端な部分適用などを使いたくなかったので、ビーム幅、ビーム長などはグローバル定数として定義しています。

Next to do

  • AHC に備えて、関連性の高い4章、7章を同様の形で実装する予定です。
  • 5章、6章については CodinGame に参戦前に実装したいと考えていますが、時期は未定です。
脚注
  1. 推移過程に応じて評価値など結果が異なる ↩︎

  2. 図解もあり非常に分かりやすく、購入をお勧めします ↩︎

Discussion