🐙

Rustでasync/await文を利用して自動でコールスタックをヒープに展開しメモ化もする (プロコン・競プロ向け)

2022/09/18に公開

モチベーション

プログラミングコンテストや競技プログラミングといった場面では、たびたびコールスタックをスタックでやらない必要が出てきます。
コールスタックというのは、関数呼び出し時にメモリ上のスタック領域に呼び出し前のローカル変数を積み上げていったりする場所なのですが、これは一般にヒープ領域よりも利用できる領域が限られています。
使いすぎると、いわゆるスタックオーバーフローと呼ばれる状態になってしまいます。体感的には 10^5 回から 10^7 回の間ぐらいで、スタフロの可能性が上がる感じがします(環境にもよりますし、これはあくまでも感覚値ですが)。言語によってはオーバーヘッドにより、この数はさらに小さくなることかと思われます。

そのほかにも、以下のようなモチベーションがあります。

  • 簡単に再帰関数を記述できるマクロが欲しい
  • 簡単にメモ化を記述できるマクロが欲しい
  • 競プロでも使いたいので、手続的マクロは無しで、宣言的マクロ (macro_rules!)のみで実現したい
  • 簡単にローカル変数をキャプチャしたい

コールスタックを展開するとは

例として三項間漸化式 a_{n+2} = a_{n} + a_{n+1}、初項 a_0 = 0, a_1 = 1 (フィボナッチ数列) を求める関数 fib を考えてみます。(簡単のため、107で割ります)

fn fib(n: u64) -> u64 {
    match n {
        0 => 0,
        1 => 1,
        n => (fib(n - 2) + fib(n - 1)) % 107,
    }
}

fn main() {
    assert_eq!(fib(20), 6765 % 107);
    // assert_eq!(fib(40), 102334155 % 107); // コメントアウトする場合は気を付けてください
}

これは2つの問題を抱えています。1つは同じ引数に対して二度以上計算してしまう、もう一つはコールスタックを利用してしまうという点です。

同じ引数に対して二度以上計算してしまう、という問題はメモ化することで対処、コールスタックについては計算過程のステートをスタックで自分で管理することで対処できます。

以下に愚直に実装する場合の例を載せておきます。

メモ化をする例
use std::collections::HashMap;

fn fib(n: u64, memo: &mut HashMap<u64, u64>) -> u64 {
    if let Some(r) = memo.get(&n) {
        return *r;
    }
    let r = match n {
        0 => 0,
        1 => 1,
        n => (fib(n - 2, memo) + fib(n - 1, memo)) % 107,
    };
    memo.insert(n, r);
    r
}

fn main() {
    let mut memo = HashMap::new();
    assert_eq!(fib(20, &mut memo), 6765 % 107);
    assert_eq!(fib(40, &mut memo), 102334155 % 107);
}
自分でコールスタックを管理する例
use core::cell::RefCell;
use std::rc::Rc;

fn fib(n: u64) -> u64 {
    enum State {
        Start {
            arg: u64,
            ret: Option<Rc<RefCell<u64>>>,
        },
        Wait {
            ret: Option<Rc<RefCell<u64>>>,
            out1: Rc<RefCell<u64>>,
            out2: Rc<RefCell<u64>>,
        },
        Output {
            ret: Option<Rc<RefCell<u64>>>,
            out: u64,
        },
    }
    let mut stack = vec![State::Start { arg: n, ret: None }];

    while let Some(top) = stack.pop() {
        match top {
            State::Start { arg, ret } => match arg {
                0 | 1 => {
                    stack.push(State::Output { out: arg, ret });
                }
                _ => {
                    let out1 = Rc::new(RefCell::new(Default::default()));
                    let out2 = Rc::new(RefCell::new(Default::default()));
                    stack.push(State::Wait {
                        ret,
                        out1: out1.clone(),
                        out2: out2.clone(),
                    });
                    stack.push(State::Start {
                        arg: arg - 2,
                        ret: Some(out1),
                    });
                    stack.push(State::Start {
                        arg: arg - 1,
                        ret: Some(out2),
                    });
                }
            },
            State::Wait { ret, out1, out2 } => {
                stack.push(State::Output {
                    ret,
                    out: (*out1.borrow_mut() + *out2.borrow_mut()) % 107,
                });
            }
            State::Output { out, ret } => match ret {
                Some(ret) => {
                    *ret.borrow_mut() = out;
                }
                None => return out,
            },
        }
    }
    unreachable!();
}

fn main() {
    assert_eq!(fib(20), 6765 % 107);
    // assert_eq!(fib(40), 102334155 % 107); // コメントアウトする場合は気を付けてください
}

async/awaitの利用

今回私が目をつけたのは、async/await構文です。
tokioチュートリアルのAsync in depthをみたときに、Rustのasync/awaitはスレッドモデルの定義すら自分でできてしまうということがわかり、これはかなり自由に使えるのではないか、と思っていました。

そこで今回、まさにasync/awaitが使えたわけですが、これはRustがasyncブロックに対して生成する Future がまさにステートマシンである、というようなところに起因します。

細かい学習の軌跡はスクラップの方にも載せていますので、こちらも参考になるかもしれません。こちらも引き続き、追記していこうと思います。

rec!マクロ

最終的には以下のような構文に落ち着きました。

fn main() {
    rec! {
        async fn fib(n: u64) -> u64 {
            match n {
                0 => 0,
                1 => 1,
                n => (fib(n - 2).await + fib(n - 1).await) % 107,
            }
        }
    }

    assert_eq!(fib(20), 6765 % 107);
    // assert_eq!(fib(40), 102334155 % 107); // コメントアウトする場合は気を付けてください
}

デフォルトではメモ化をしない、という戦略をとるようになっており、Memoトレイトを実装した構造体を指定することができます。

use std::collections::HashMap;
fn main() {
    rec! {
        #[memo(HashMap::<_, _>::new())]
        async fn fib(n: u64) -> u64 {
            match n {
                0 => 0,
                1 => 1,
                n => (fib(n - 2).await + fib(n - 1).await) % 107,
            }
        }
    }

    assert_eq!(fib(20), 6765 % 107);
    assert_eq!(fib(40), 102334155 % 107);
    assert_eq!(fib(1000000), 86);
}

以下が細かい仕様になります。

細かい仕様
  • #[memo(<expr>)] については実際にattributeとして処理されているわけではなく、パターンマッチで拾っているだけです。
  • 他の属性は今のところ使えません。
  • 外側の変数は不変借用のみ可能です。(特に何もせずにキャプチャできます)
  • async/awaitは必須です。
  • 引数は任意個指定可能です。
  • トレイト境界の指定 (<...>where 句) や、impl ... 構文の、引数や返り値の型への利用はできません。
  • 引数の型は省略不可能です。
  • 返り値は省略可能です。
  • 複数個の関数を一つのrec!マクロで列挙することは(現時点では)不可能です。
  • 外からの呼び出しの際には await はつけません。
  • 変数名はすべて、識別子でなければいけません。(パターンマッチは使えません。 _ 単体もパターンであり、使えないので注意)
  • async/await構文を利用していますが、実行はシングルスレッドで完全に同期的に行われます。
    • 「非同期用構文だから無駄な待ちが発生している」等はありません。
  • 関数名がローカル変数として定義され、スコープを出ると使えなくなります。
  • 関数の本体は完全に自由で、関数名以外の使えない変数・型がある(勝手にシャドーイングされる)といったことも一切ありません。
    • 頑張って関数名の変数一つを使いまわして意図しないシャドーイングを回避しています。
  • 引数は全て Clone + 'static、返り値は 'static を満たす必要があります。
    • 引数は大抵 Copy まで満たしていると思っているので、問題ないかと思っています。
    • メモ化する場合は、(少なくともデフォルトで用意した方法は)返り値に対して追加で Clone を要求します。
  • メモ化する場合は、それぞれ個別で追加の要求があるものがあります。ハッシュマップを使う場合は引数全てに対して Eq + Hash、など。
  • セット系のメモは返り値なし (()) の場合にのみ実装されています。
  • #[memo(型)] はよく使いそうなものへのショートカットを用意しています。 (本質ではない)
    • #[memo_vec]: Vec<_>::new() のショートハンド。なお、ショートハンドなしで Vec 単体から推論させることはできません。以下同様。
    • #[memo_map], #[memo_hashmap]: HashMap<_, _>::new() のショートハンド。
    • #[memo_btreemap]: BTreeMap<_, _>::new() のショートハンド。
    • #[memo_set], #[memo_hashset]: HashSet<_, _>::new() のショートハンド。
    • #[memo_btreeset]: BTreeSet<_, _>::new() のショートハンド。
  • unsafe をたくさん使っています。いくつかは、Future が非同期的な状況を想定しているのに対し、同期的に利用することにおけるものです。そのほかは、静的に安全が保証可能な RefCell を高速化する形で利用されているもの、などです。

マクロの定義全体

マクロだけではなく、構造体の定義等もすべて入っています。マクロは rec! のみです。

マクロの定義全体
macro_rules! rec {
    (
        #[memo_vec]
        $($tt:tt)*
    ) => {
        rec! {
            #[memo(::std::vec::Vec::<_>::new())]
            $($tt)*
        }
    };
    (
        #[memo_map]
        $($tt:tt)*
    ) => {
        rec! {
            #[memo(::std::collections::HashMap::<_, _>::new())]
            $($tt)*
        }
    };
    (
        #[memo_hashmap]
        $($tt:tt)*
    ) => {
        rec! {
            #[memo(::std::collections::HashMap::<_, _>::new())]
            $($tt)*
        }
    };
    (
        #[memo_btreemap]
        $($tt:tt)*
    ) => {
        rec! {
            #[memo(::std::collections::BTreeMap::<_, _>::new())]
            $($tt)*
        }
    };
    (
        #[memo_set]
        $($tt:tt)*
    ) => {
        rec! {
            #[memo(::std::collections::HashSet::<_, _>::new())]
            $($tt)*
        }
    };
    (
        #[memo_hashset]
        $($tt:tt)*
    ) => {
        rec! {
            #[memo(::std::collections::HashSet::<_, _>::new())]
            $($tt)*
        }
    };
    (
        #[memo_btreeset]
        $($tt:tt)*
    ) => {
        rec! {
            #[memo(::std::collections::BTreeSet::<_, _>::new())]
            $($tt)*
        }
    };
    (
        #[memo($memo:expr)]
        async fn $name:ident($($arg:ident : $arg_type:ty),*) -> $ret_type:ty {
            $($body:tt)*
        }
    ) => {
        let $name = {
            #[allow(unused_parens)]
            $crate::Rec::<_, ($($arg_type),*), $ret_type, _, _>::new(|$name, ($($arg),*): ($($arg_type),*)| {
                let $name = $crate::ForceMover(($name, ($($arg),*)));
                #[warn(unused_parens)]
                async {
                    let $name = $name;
                    let $name = $name.0;
                    #[allow(unused_parens)]
                    let ($name, ($($arg),*)) = $name;
                    #[allow(unused_variables)]
                    let $name = |$($arg:$arg_type),*| unsafe { (*$name)(($($arg),*)) };
                    $($body)*
                }
            }, $memo)
        };
        let $name = unsafe { ::core::pin::Pin::new_unchecked(&$name) };
        let $name = ($name,);
        let $name = ($name.0, |$($arg : $arg_type),*| unsafe { $name.0.call($name.0.me(), ($($arg),*)) });
        let $name = $name.1;
    };

    (
        #[memo($($memo:tt)*)]
        async fn $name:ident($($arg:ident : $arg_type:ty),*)  {
            $($body:tt)*
        }
    ) => {
        rec! {
            #[memo($($memo)*)]
            async fn $name($($arg : $arg_type),*) -> () {
                $($body)*
            }
        }
    };

    (
        async fn $name:ident($($arg:ident : $arg_type:ty),*) $(-> $ret_type:ty)? {
            $($body:tt)*
        }
    ) => {
        rec! {
            #[memo($crate::NoMemo {})]
            async fn $name($($arg : $arg_type),*) $(-> $ret_type)? {
                $($body)*
            }
        }
    };
}

fn waker_do_nothing_vtable() -> &'static ::core::task::RawWakerVTable {
    unsafe fn clone_raw(data: *const ()) -> ::core::task::RawWaker {
        ::core::task::RawWaker::new(data, waker_do_nothing_vtable())
    }
    unsafe fn wake_raw(_data: *const ()) {}
    unsafe fn wake_by_ref_raw(_data: *const ()) {}
    unsafe fn drop_raw(_data: *const ()) {}
    &::core::task::RawWakerVTable::new(clone_raw, wake_raw, wake_by_ref_raw, drop_raw)
}
fn new_waker_do_nothing() -> ::core::task::Waker {
    let raw_waker = ::core::task::RawWaker::new(::core::ptr::null(), waker_do_nothing_vtable());
    unsafe { ::core::task::Waker::from_raw(raw_waker) }
}

struct PopperInner<Args, T> {
    // TODO: Use bare type.
    waker: ::core::option::Option<::core::task::Waker>,
    res: ::core::option::Option<T>,
    args: ::core::option::Option<Args>,
}
// Generally unsafe. Only usage from Rec is safe.
struct Popper<Args, T> {
    // This realizes virtual pinning.
    // Internal status is only below, no need for real pinning.
    inner: *const ::std::vec::Vec<PopperInner<Args, T>>,
}
impl<Args, T> ::core::future::Future for Popper<Args, T> {
    type Output = T;
    fn poll(
        self: ::core::pin::Pin<&mut Self>,
        cx: &mut ::core::task::Context<'_>,
    ) -> ::core::task::Poll<T> {
        let inner_vec = unsafe { &mut *(self.inner as *mut ::std::vec::Vec<PopperInner<Args, T>>) };

        #[allow(clippy::cast_ref_to_mut)]
        let inner = unsafe {
            &mut *(inner_vec.last().unwrap_unchecked() as *const _ as *mut PopperInner<Args, T>)
        };
        if let ::core::option::Option::Some(res) = inner.res.take() {
            // unchecked pop
            unsafe {
                debug_assert!(!inner_vec.is_empty());
                inner_vec.set_len(inner_vec.len() - 1);
                drop(::core::ptr::read(inner_vec.as_ptr().add(inner_vec.len())));
            }

            ::core::task::Poll::Ready(res)
        } else {
            debug_assert!(inner.waker.is_none());
            inner.waker.replace(cx.waker().clone());
            ::core::task::Poll::Pending
        }
    }
}

struct Rec<M: Memo<Args, Output> + 'static, Args, Output, T0, F0> {
    f: T0,
    arg: ::core::option::Option<Args>,
    popper_inner_stack: ::std::vec::Vec<PopperInner<Args, Output>>,
    // NOTE: Rust generated futures need actual pinning.
    future_stack: ::std::vec::Vec<::core::pin::Pin<Box<F0>>>,
    memo: M,
}

impl<
        M: Memo<Args, Output>,
        Args: Clone + 'static,
        Output: 'static,
        T0: Fn(*const dyn Fn(Args) -> Popper<Args, Output>, Args) -> F0,
        F0: ::core::future::Future<Output = Output>,
    > Rec<M, Args, Output, T0, F0>
{
    fn new(f: T0, memo: M) -> Self {
        Self {
            f,
            arg: ::core::option::Option::None,
            popper_inner_stack: ::std::vec::Vec::new(),
            future_stack: ::std::vec::Vec::new(),
            memo,
        }
    }

    #[allow(clippy::cast_ref_to_mut)]
    unsafe fn me(self: ::core::pin::Pin<&Self>) -> Box<dyn Fn(Args) -> Popper<Args, Output>> {
        let this_arg_ptr = &self.arg as *const _;
        let this_popper_inner_stack_ptr = &self.popper_inner_stack as *const _;
        let this_memo_ptr = &self.memo as *const _;
        Box::new(move |args: Args| {
            let this_arg = &mut *(this_arg_ptr as *mut ::core::option::Option<Args>);
            let this_popper_inner_stack = &mut *(this_popper_inner_stack_ptr
                as *mut ::std::vec::Vec<PopperInner<Args, Output>>);
            let this_memo = &mut *(this_memo_ptr as *const _ as *mut M);
            let res = Memo::get_memo(this_memo, &args);
            if res.is_none() {
                debug_assert!(this_arg.is_none());
                this_arg.replace(Clone::clone(&args));
            }
            this_popper_inner_stack.push(PopperInner::<Args, Output> {
                waker: ::core::option::Option::None,
                res,
                args: ::core::option::Option::Some(Clone::clone(&args)),
            });
            Popper {
                // SAFETY: pointer to el of Vec may be moved, but pointer to Vec for getting last element can be treated as pinned.
                inner: this_popper_inner_stack as *const _,
            }
        })
    }

    #[allow(clippy::cast_ref_to_mut)]
    unsafe fn call(
        self: ::core::pin::Pin<&Self>,
        me: Box<dyn Fn(Args) -> Popper<Args, Output>>,
        args: Args,
    ) -> Output {
        if let ::core::option::Option::Some(e) = Memo::get_memo(&self.memo, &args) {
            return e;
        }

        let me_ptr = &*me as *const _;
        let root_future = (self.f)(me_ptr, Clone::clone(&args));

        let this_future_stack = &mut *(&self.future_stack as *const _
            as *mut ::std::vec::Vec<::core::pin::Pin<Box<F0>>>);
        this_future_stack.push(Box::pin(root_future));

        let this_arg = &mut *(&self.arg as *const _ as *mut ::core::option::Option<Args>);
        let this_popper_inner_stack = &mut *(&self.popper_inner_stack as *const _
            as *mut ::std::vec::Vec<PopperInner<Args, Output>>);

        while let ::core::option::Option::Some(top) = this_future_stack.last() {
            // We know it'll be always waken in next loop.
            let waker = new_waker_do_nothing();
            let cx = &mut ::core::task::Context::from_waker(&waker);

            if let ::core::task::Poll::Ready(r) =
                { &mut *(top as *const _ as *mut ::core::pin::Pin<Box<F0>>) }
                    .as_mut()
                    .poll(cx)
            {
                debug_assert_eq!(this_future_stack.len(), this_popper_inner_stack.len() + 1);

                // pop unchecked
                debug_assert!(!this_future_stack.is_empty());
                this_future_stack.set_len(this_future_stack.len() - 1);
                drop(::core::ptr::read(
                    this_future_stack.as_ptr().add(this_future_stack.len()),
                ));

                if this_popper_inner_stack.is_empty() {
                    debug_assert!(this_future_stack.is_empty());
                    let this_memo = &mut *(&self.memo as *const _ as *mut M);
                    Memo::insert_memo(this_memo, Clone::clone(&args), &r);
                    return r;
                } else {
                    let top = {
                        &mut *(this_popper_inner_stack.last().unwrap_unchecked() as *const _
                            as *mut PopperInner<Args, Output>)
                    };

                    let this_memo = &mut *(&self.memo as *const _ as *mut M);
                    Memo::insert_memo(this_memo, top.args.take().unwrap_unchecked(), &r);
                    top.res = ::core::option::Option::Some(r);
                    // It's waking Rust generated Futures (F0).
                    debug_assert!(top.waker.is_some());
                    top.waker.take().unwrap_unchecked().wake();
                }

                continue;
            }
            debug_assert_eq!(this_future_stack.len(), this_popper_inner_stack.len());
            let arg = this_arg.take().unwrap_unchecked();

            let future = (self.f)(me_ptr, arg);
            this_future_stack.push(Box::pin(future));
        }
        if cfg!(debug_assertions) {
            unreachable!();
        }
        ::core::hint::unreachable_unchecked()
    }
}

/// Important difference to tuple is there is no Copy trait on this even if T is so.
struct ForceMover<T>(T);

trait Memo<Args, Output> {
    fn get_memo(&self, args: &Args) -> ::core::option::Option<Output>;
    fn insert_memo(&mut self, args: Args, output: &Output);
}

impl<Args: ::std::cmp::Eq + ::std::hash::Hash, Output: Clone> Memo<Args, Output>
    for ::std::collections::HashMap<Args, Output>
{
    fn get_memo(&self, args: &Args) -> ::core::option::Option<Output> {
        self.get(args).cloned()
    }
    fn insert_memo(&mut self, args: Args, output: &Output) {
        self.insert(args, output.clone());
    }
}

impl<Args: ::std::cmp::Ord, Output: Clone> Memo<Args, Output>
    for ::std::collections::BTreeMap<Args, Output>
{
    fn get_memo(&self, args: &Args) -> ::core::option::Option<Output> {
        self.get(args).cloned()
    }
    fn insert_memo(&mut self, args: Args, output: &Output) {
        self.insert(args, output.clone());
    }
}

impl<Args: ::std::cmp::Eq + ::std::hash::Hash> Memo<Args, ()>
    for ::std::collections::HashSet<Args>
{
    fn get_memo(&self, args: &Args) -> ::core::option::Option<()> {
        if self.contains(args) {
            ::core::option::Option::Some(())
        } else {
            ::core::option::Option::None
        }
    }
    fn insert_memo(&mut self, args: Args, _output: &()) {
        self.insert(args);
    }
}

impl<Args: ::std::cmp::Ord> Memo<Args, ()> for ::std::collections::BTreeSet<Args> {
    fn get_memo(&self, args: &Args) -> ::core::option::Option<()> {
        if self.contains(args) {
            ::core::option::Option::Some(())
        } else {
            ::core::option::Option::None
        }
    }
    fn insert_memo(&mut self, args: Args, _output: &()) {
        self.insert(args);
    }
}

trait SerializeAsUsize {
    fn serialize_as_usize(self) -> usize;
}

impl SerializeAsUsize for usize {
    fn serialize_as_usize(self) -> usize {
        self
    }
}

impl SerializeAsUsize for u8 {
    fn serialize_as_usize(self) -> usize {
        self as usize
    }
}

impl SerializeAsUsize for u16 {
    fn serialize_as_usize(self) -> usize {
        self as usize
    }
}

impl SerializeAsUsize for u32 {
    fn serialize_as_usize(self) -> usize {
        self as usize
    }
}

impl SerializeAsUsize for u64 {
    fn serialize_as_usize(self) -> usize {
        self as usize
    }
}

impl SerializeAsUsize for u128 {
    fn serialize_as_usize(self) -> usize {
        self as usize
    }
}

impl SerializeAsUsize for bool {
    fn serialize_as_usize(self) -> usize {
        self as usize
    }
}

impl<Args: SerializeAsUsize + Copy, Output: Clone> Memo<Args, Output>
    for ::std::vec::Vec<::core::option::Option<Output>>
{
    fn get_memo(&self, args: &Args) -> ::core::option::Option<Output> {
        match self.get(SerializeAsUsize::serialize_as_usize(*args)) {
            ::core::option::Option::Some(::core::option::Option::Some(inner)) => {
                ::core::option::Option::Some(inner.clone())
            }
            _ => ::core::option::Option::None,
        }
    }
    fn insert_memo(&mut self, args: Args, output: &Output) {
        let index = SerializeAsUsize::serialize_as_usize(args);
        while self.len() <= index {
            self.push(::core::option::Option::None);
        }
        self[index] = ::core::option::Option::Some(output.clone());
    }
}

struct NoMemo {}

impl<Args, Output> Memo<Args, Output> for NoMemo {
    fn get_memo(&self, _args: &Args) -> ::core::option::Option<Output> {
        ::core::option::Option::None
    }
    fn insert_memo(&mut self, _args: Args, _output: &Output) {}
}

また、執筆時点ではAtCoderのRustのバージョンは1.42.0なので、以下のようなポリフィルが必要です。

Optionのポリフィル
pub trait OptionPolyfill<T> {
    #[allow(clippy::missing_safety_doc)]
    unsafe fn unwrap_unchecked(self) -> T;
}
impl<T> OptionPolyfill<T> for ::core::option::Option<T> {
    unsafe fn unwrap_unchecked(self) -> T {
        debug_assert!(self.is_some());
        match self {
            ::core::option::Option::Some(val) => val,
            // SAFETY: the safety contract must be upheld by the caller.
            ::core::option::Option::None => ::std::hint::unreachable_unchecked(),
        }
    }
}
  • AtCoderでの使用例(ネタバレ注意): 提出 #34970260
    • 最速帯が2ms程度で、この提出が8ms。

強化したいこととか

かなりよくできたというふうに思うので、以下ぐらいが課題かなと思っています(技術的に無理だとわかった要望が除かれています)。

  • 自動で相互再帰
  • #[allow(unused)] などの属性の引き継ぎ
  • 末尾再帰での最適化(検出はせずに属性でunsafe前提で指定する)
  • メモ化だけする方法の提供
  • LRUキャッシュでのメモ化とか (実用方面)

また、スクラップの方では、諦めた仕様など(自動でawaitするとか)について、解説と共に載せていこうかと思っています。

その他の使用例

fact_manual_memo: 外から不変借用して、メモを自分で管理して階乗を求める例
use core::cell::RefCell;
use std::collections::HashMap;
fn main() {
    let memo = RefCell::new(HashMap::<u64, u64>::new());
    rec! {
        async fn fact_manual_memo(x: u64) -> u64 {
            if let Option::Some(e) = memo.borrow_mut().get(&x) {
                return *e;
            }
            let r = {
                if x == 0 {
                    1
                } else {
                    x * fact_manual_memo(x - 1).await % 1000000007
                }
            };
            memo.borrow_mut().insert(x, r);
            r
        }
    }

    // 457992974
    println!("fact_manual_memo(100000)={}", fact_manual_memo(100000));
}
fact: シンプルに階乗を求め、メモ化する例
fn main() {
    rec! {
        #[memo_map]
        async fn fact(x: usize) -> u64 {
            if x == 0 {
                1
            } else {
                x as u64 * fact(x - 1).await % 1000000007
            }
        }
    }

    // 457992974
    println!("fact(100000)={}", fact(100000));
    for _ in 0..100000 {
        assert_eq!(fact(100000), 457992974);
    }
}
collatz: コラッツ問題のステップの経過をプリント、一度訪れたら表示しない、セットでメモ化する例
fn main() {
    rec! {
        // #[memo_vec]
        // #[memo_map]
        #[memo_set]
        async fn collatz(x: u64) {
            println!("enter: {}", x);
            if x <= 1 {
                return;
            }
            if x % 2 == 0 {
                collatz(x / 2).await;
            } else {
                collatz(x * 3 + 1).await;
            }
        }
    }

    collatz(200);
    collatz(30);
    collatz(90);
}
gcd: メモ化せずにシンプルにGCDを求める例
fn main() {
    rec! {
        async fn gcd(a: u64, b: u64) -> u64 {
            if b == 0 {
                a
            } else {
                gcd(b, a % b).await
            }
        }
    }

    // 1
    println!("gcd(1893, 1742)={}", gcd(1893, 1742));
}
is_odd, is_even: 同じメモを使い回す例
use std::cell::RefCell;
use std::rc::Rc;

impl<Args, Output, T: Memo<Args, Output>> Memo<Args, Output> for Rc<RefCell<T>> {
    fn get_memo(&self, args: &Args) -> Option<Output> {
        self.borrow_mut().get_memo(args)
    }
    fn insert_memo(&mut self, args: Args, output: &Output) {
        self.borrow_mut().insert_memo(args, output)
    }
}

struct BiasedMemo<T> {
    base: T,
    bias: u64,
}
impl<Output, T: Memo<u64, Output>> Memo<u64, Output> for BiasedMemo<T> {
    fn get_memo(&self, args: &u64) -> Option<Output> {
        self.base.get_memo(&(args + self.bias))
    }
    fn insert_memo(&mut self, args: u64, output: &Output) {
        self.base.insert_memo(args + self.bias, output)
    }
}

fn main() {
    let memo = Rc::new(RefCell::new(Vec::new()));
    rec! {
        #[memo(Rc::clone(&memo))]
        async fn is_odd(x: u64) -> bool {
            println!("enter: is_odd({})", x);
            if x == 0 {
                false
            } else {
                !is_odd(x - 1).await
            }
        }
    }

    let memo1 = BiasedMemo {
        base: Rc::clone(&memo),
        bias: 1,
    };
    rec! {
        #[memo(memo1)]
        async fn is_even(x: u64) -> bool {
            println!("enter: is_even({})", x);
            if x == 0 {
                true
            } else {
                !is_even(x - 1).await
            }
        }
    }

    println!("is_odd(4)={}", is_odd(4));
    println!("is_even(4)={}", is_even(4));
}

構造体 ForceMover がなぜ必要なのか、に関する解説

https://zenn.dev/luma/articles/rust-why-and-how-force-move-copy-trait

更新履歴

  • 2022/09/18 01時頃: 初版公開
  • 2022/09/18 16時頃: 以下の修正
    • ::std の形の fully qualified に。 (Vecを自分のに置き換えるとかはできなくなるので、必要であれば適宜書き換えて使ってください。)
    • 変数の名前空間を消費しないと言いつつ、args を使ってしまっていたので、使わないように修正。
    • signed int の vector によるメモ化は 128 + (i8の変数) のようにしていたが、 0, 1, -1, 2, -2, \cdots のように番号付けしたいこともあるだろうと思い、実装を用意しないことにした。
    • Rec のメソッドの一部を unsafe 指定しました。 ( pub 指定じゃないから(仮にモジュールだとしても)別にいいかなと思いつつ)
    • 「細かい仕様」の追記。
    • #その他の使用例セクションを追加しました。
    • メモ化の指定方法を変えました。型ではなく、初期化用のをそのまま渡します。
      • 例えば、メモを複数の関数で使いまわしたい場合は #memo(Rc::clone(&memo)) などが指定できます。 #その他の使用例 の「同じメモを使い回す例」に例を載せています。
    • その他、細かい修正。
  • 2022/09/18 17時頃: #[memo(vec)] 等のショートハンドを #[memo_vec] などのように変更。

ライセンス

プロコン等ではご自由にスニペットとしてお使いください。リンク等貼っていただける場合は、この記事へのリンクで構いません。必須ではありません。

クレート化する等の場合は一声いただけると幸いです。 私としては、より実践的な場面では明示的にメモ化の処理を書くべきかなと思っており、あまり濫用されるべきではないのかなと思っています。しかし、コールスタックの展開に関しては有用かもしれません。要望等がありましたら、検討したいと思います。

GitHubで編集を提案

Discussion