Open11

Rustでコールスタックのヒープへの自動展開と再帰・相互再帰の自動メモ化を頑張ろうとする奮闘記

lumaluma

まずはPinの理解から

RustのPinチョットワカル - OPTiM TECH BLOG をベースに見ていきます。

ムーブが起こってポインタが同期できない例を引用
struct SelfRef {
    x: u32,
    // ptrは常にxを指していて欲しいが、SelfRefがムーブした瞬間に別のアドレスを指すようになる
    ptr: *const u32,
}

impl SelfRef {
    pub fn new(x: u32) -> SelfRef {
        let mut this = SelfRef {
            x,
            ptr: std::ptr::null(),
        };
        this.ptr = &this.x;

        // まだアドレスは変わらないのでテストは成功する
        assert_eq!(&this.x as *const _, this.ptr);

        // ここで値を返した瞬間にxのアドレスが変わり、ptrの値が不正となる
        this
    }
}

fn main() {
    let v = SelfRef::new(0);

    // v.xとv.ptrの値が異なるためテスト失敗
    assert_eq!(&v.x as *const _, v.ptr);
}

これはそもそもなぜポインタが動いてしまうのかというところですが、一応godboltで見てみます。以下のようにコードを変えておきます。

変更後のRustコード
struct SelfRef {
    x: u32,
    ptr: *const u32,
}

impl SelfRef {
    pub fn new(x: u32) -> SelfRef {
        let mut this = SelfRef {
            x,
            ptr: std::ptr::null(),
        };
        this.ptr = &this.x;

        println!("{}", (&this.x as *const _) == this.ptr);

        this
    }
}

pub fn main() {
    let v = SelfRef::new(0);

    println!("{}", (&v.x as *const _) == v.ptr);
}

これを -C opt-level=0 でコンパイルします。( https://rust.godbolt.org/z/ohsTojTbE で見れます。)

重要なところをかいつまんで書くと、 example::SelfRef::new: は以下のようになります。
(スラッシュコメントは私が追記)

example::SelfRef::new:
        sub     rsp, 136  // rspはスタックポインタ
        mov     dword ptr [rsp + 28], edi  // ediレジスタに引数xの値が格納されている、スタック[rsp + 28]に移されてる 
        mov     qword ptr [rsp + 128], 0
        mov     rdi, qword ptr [rsp + 128]
        call    qword ptr [rip + core::ptr::metadata::from_raw_parts@GOTPCREL]
        mov     qword ptr [rsp + 32], rax
        mov     rax, qword ptr [rsp + 32]
        mov     ecx, dword ptr [rsp + 28] // 引数xをecxにコピー
        mov     dword ptr [rsp + 48], ecx  // this変数のスタック領域上のxの部分が[rsp+48](32bit)
        // ↑ ecxは引数xの値が入っている。それをthis.xにコピーしている
        mov     qword ptr [rsp + 40], rax   // this変数のスタック領域上のptrの部分が[rsp+40](64bit)
        lea     rax, [rsp + 40]  //add     rax, 8  // 上記と併せて、rsp + 48 、つまりthis.xの位置をraxレジスタに格納
        ...(中略)
        mov     rax, qword ptr [rsp + 40]  // 返り値用のレジスタraxにthis.ptrを格納
        mov     edx, dword ptr [rsp + 48]  // 返り値用のレジスタedxにthis.xを格納
        add     rsp, 136
        ret    // 関数呼び出しをリターン

さて、とにかく example::SelfRef::new: は、this.xの値と、それをスタックにしまっていた時の、そのスタック上のポインタをthis.ptrとして返してしまっている。

ここでメイン関数を確認する。

example::main:
        sub     rsp, 104
        xor     edi, edi  // レジスタediを0にセット(これが引数)
        call    example::SelfRef::new  // 上記の呼び出し。
        mov     dword ptr [rsp + 24], edx  // レジスタedxにはthis.xが入っているが、それをスタック[rsp + 24]にコピー
        mov     qword ptr [rsp + 16], rax  // レジスタraxにはthis.ptrが入っているが、それをスタック[rsp + 16]にコピー

さて、raxに (rsp + 24) = this.xのポインタが入っているのを望んでいるという文脈だったが、実際には (rsp + 48 - 136) (注意: ret前に add rsp, 136 されている、rspはこの時点でのrspの値を指す)が入っている。
つまり、SelfRef::new:内でのかつてのスタック上のポインタを指してしまっている。
ムーブでもこれくらいのコピーは起きてしまい、不整合が生じる。

呟き: (正直Rustのムーブに対する誤解があったやぁ…(なのでここまで調べた…)。Copyトレイトがないとこれくらいのコピーも起きないものかとちょっと思ってた…)

ちなみに、Copy が実装されていない Vec<u32> でも同様のことが起きることも確認できた。

Boxを使ってヒープに移せば一旦はアサーションが通ります。

Boxを使ったアサーションが通る実装
struct SelfRef {
    x: Box<u32>,
    ptr: *const u32,
}

impl SelfRef {
    pub fn new(x: u32) -> SelfRef {
        let mut this = SelfRef {
            x: Box::new(x),
            ptr: std::ptr::null(),
        };
        this.ptr = &*this.x;

        assert_eq!(&*this.x as *const _, this.ptr);

        this
    }
}

fn main() {
    let v = SelfRef::new(0);

    assert_eq!(&*v.x as *const _, v.ptr);
}

上記の記事ではその後にBoxでもうまくいかないケースなどについて解説されていますね。
(あと、正確なところではBoxがポインタを維持することって保証されてると思っちゃってよくない可能性などありそう??)

lumaluma

現時点での、こう書ける用意にしたいという目標

rec!(
  is_odd = |v: u64| {
    if v == 0 { false } else { is_even(v - 1) }
  };
  is_even = |v: u64| {
    if v == 0 { true } else { is_odd(v - 1) }
  };
)
println!("{}, {}", is_odd(10),  is_even(10));

潜り込むように置換するだけであればproc_macroはなくてもいけるんじゃないかなという算段。
スタック展開やメモ化戦略の指定方法は後で考える。(まだ技術的制約がわからない)

lumaluma

Copyトレイトとはなんなのか

Pinを通してわかったCopyトレイトについて。

let some = Some::new();
let another = some;

上記のようなコードがあったとする。これは、もしSomeにCopyトレイトが実装されていなければ ムーブ が、されていれば コピー が起こる。
まず、いずれにしても、メモリレベルではコピーが起こる(可能性がある)。この時、そのメモリ上のコピーによって、所有権も移ってしまうならばRustとしては ムーブ であり、所有物が複製されるなら、それはRustとしては コピー とする(こともCopyトレイトの実装により可能)ということだ。

ドキュメントには以下のように書いてある。

Types whose values can be duplicated simply by copying bits.
source: https://doc.rust-lang.org/std/marker/trait.Copy.html

この duplicated というのが、所有物としての複製を意味する。(「所有物」、というのは所有権という言葉から勝手に作ったし、この言葉も文脈によりけりだと思うけれど。)

注意。ここまでは Some::newは Pin<..> は返さない、という前提。

ここで、Pin<T> であればさらに、Rust的な ムーブ では、メモリレベルでのコピー(とswap等による移動)が起きないようになる、ということだ(TにUnpinが実装されていなければ、ね)。

lumaluma

std::pin::Pin::new_uncheckedの使い方がsafeな例、unsafeな例をみてみる

std::pin::Pin::new_uncheckedのsafeな使用例

まず、RustのPinチョットワカル - OPTiM TECH BLOG でも紹介されている pin-utilspin_mut! マクロについて。

これはUnpinとは限らない T を安全に Pin<T> に変換してくれる。

コード例を抜粋すると以下のように紹介されている。

fn main() {
    let obj = NotUnpin::new();

    pin_mut!(obj);

    assert_pin::<NotUnpin>(&obj);

    obj.as_mut().method();
    obj.as_mut().method();
}

上記のpin_mut!を展開すると以下のようになる。

fn main() {
    let obj = NotUnpin::new();

    let mut obj = obj; // ①
    #[allow(unused_mut)]
    let mut obj = unsafe { std::pin::Pin::new_unchecked(&mut obj) }; // ②

    assert_pin::<NotUnpin>(&obj);

    obj.as_mut().method();
    obj.as_mut().method();
}

これがsafeなら一見すると std::pin::Pin::new_unchecked もsafeなのではないかと思いそうになるが、ここで重要なのは、②のシャドーイングで、①で宣言したpinする対象の変数に後からアクセスできなくなっていることが重要だ。

これは以下のようにunsafeな使用例を見ていくとさらに理解できる。

std::pin::Pin::new_uncheckedのunsafeな使用例

unsafeな例はstd::pin::Pin::new_uncheckedのドキュメントコメントに丁寧に記述されている。

fn move_pinned_ref<T>(mut a: T, mut b: T) {
    unsafe {
        let p: Pin<&mut T> = Pin::new_unchecked(&mut a);
        // This should mean the pointee `a` can never move again.
        // 意訳: pinするので今後aはムーブしちゃいけない
    }
    mem::swap(&mut a, &mut b);
    // The address of `a` changed to `b`'s stack slot, so `a` got moved even
    // though we have previously pinned it! We have violated the pinning API contract.
    // 意訳: aのアドレスがbのスタック上の位置に変わってしまった、つまりaがムーブされた。
    // 意訳: pinしたのにムーブしたので、pinのAPI契約を違反してしまった!
}

pinしたのに後から勝手にムーブしたらそれはunsafeになってしまう。

呟き: ところでこの書き方をされると、swapは言語レベルでメモリのスワップの方を行うように最適化されてるのかな、とか思ったが、そういうことではなかった(普通に中身を入れ替えていた)


unsafeな例はもう一つ載っている。そちらも確認すると良いかもしれない。

lumaluma

現時点での期待されるマクロ展開後 (コンパイルできます)

まだasync/awaitとスタック展開、メモ化等はなし。

コード
use std::cell::{Cell, RefCell};
use std::rc::{Rc, Weak};

fn main() {
    let is_odd = {
        struct Rec<'a> {
            f: RefCell<Option<&'a dyn Fn(u64) -> bool>>,
        }
        impl<'a> Rec<'a> {
            fn new() -> Self {
                Self {
                    f: RefCell::new(None),
                }
            }
            fn set(&self, f: &'a dyn Fn(u64) -> bool) {
                *self.f.borrow_mut() = Some(f);
            }
            fn call(&self, v: u64) -> bool {
                (*self.f.borrow()).unwrap()(v)
            }
        }
        Rc::new(Rec::new())
    };
    let is_even = {
        struct Rec<'a> {
            f: RefCell<Option<&'a dyn Fn(u64) -> bool>>,
        }
        impl<'a> Rec<'a> {
            fn new() -> Self {
                Self {
                    f: RefCell::new(None),
                }
            }
            fn set(&self, f: &'a dyn Fn(u64) -> bool) {
                *self.f.borrow_mut() = Some(f);
            }
            fn call(&self, v: u64) -> bool {
                (*self.f.borrow()).unwrap()(v)
            }
        }
        Rc::new(Rec::new())
    };

    let is_odd = (
        &|v: u64| -> bool {
            if v == 0 {
                false
            } else {
                Rc::clone(&is_even).call(v - 1)
            }
        },
        Rc::clone(&is_odd),
    );
    is_odd.1.set(is_odd.0);
    let is_odd = is_odd.1;

    let is_even = (
        &|v: u64| -> bool {
            if v == 0 {
                true
            } else {
                Rc::clone(&is_odd).call(v - 1)
            }
        },
        Rc::clone(&is_even),
    );
    is_even.1.set(is_even.0);
    let is_even = is_even.1;

    let is_odd = |v: u64| is_odd.call(v);
    let is_even = |v: u64| is_even.call(v);

    println!("{}, {}", is_odd(10), is_even(10));
}

ちなみにreference countするのはこの場合は完全に無駄で省けると思うんだけど、それをunsafeを駆使してどうやるとかはまだわからんね…

lumaluma

単一再起のFutureによる自動ヒープ上コールスタックの展開後の姿

展開後に期待するのは以下のようなコードです。単一の実行rsファイルとして動きます。

単一再起のFutureによる自動ヒープ上コールスタック
fn waker_do_nothing_vtable() -> &'static core::task::RawWakerVTable {
    unsafe fn clone_raw(data: *const ()) -> core::task::RawWaker {
        core::task::RawWaker::new(data, waker_do_nothing_vtable())
    }
    unsafe fn wake_raw(_data: *const ()) {}
    unsafe fn wake_by_ref_raw(_data: *const ()) {}
    unsafe fn drop_raw(_data: *const ()) {}
    &core::task::RawWakerVTable::new(clone_raw, wake_raw, wake_by_ref_raw, drop_raw)
}
fn new_waker_do_nothing() -> core::task::Waker {
    let raw_waker = core::task::RawWaker::new(core::ptr::null(), waker_do_nothing_vtable());
    unsafe { core::task::Waker::from_raw(raw_waker) }
}

struct PopperInner<T> {
    // TODO: Use bare type.
    waker: Option<core::task::Waker>,
    res: Option<T>,
}
#[repr(transparent)]
struct Popper<T> {
    // This realizes virtual pinning.
    // Internal status is only below, no need for real pinning.
    inner: *const Vec<PopperInner<T>>,
}
impl<T> core::future::Future for Popper<T> {
    type Output = T;
    fn poll(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<T> {
        let inner_vec = unsafe { &mut *(self.inner as *mut Vec<PopperInner<T>>) };

        #[allow(clippy::cast_ref_to_mut)]
        let inner = unsafe {
            &mut *(inner_vec.last().unwrap_unchecked() as *const _ as *mut PopperInner<T>)
        };
        if let Some(res) = inner.res.take() {
            unsafe {
                debug_assert!(!inner_vec.is_empty());
                inner_vec.set_len(inner_vec.len() - 1);
                drop(core::ptr::read(inner_vec.as_ptr().add(inner_vec.len())));
            }

            core::task::Poll::Ready(res)
        } else {
            debug_assert!(inner.waker.is_none());
            inner.waker.replace(cx.waker().clone());
            core::task::Poll::Pending
        }
    }
}

fn main() {
    #[allow(clippy::cast_ref_to_mut)]
    let is_odd = {
        struct Rec<T0, F0> {
            f: T0,
            // TODO: Use bare type.
            arg: Option<u64>,
            popper_inner_stack: Vec<PopperInner<bool>>,
            // NOTE: Rust generated futures need actual pinning.
            future_stack: Vec<core::pin::Pin<Box<F0>>>,
        }
        impl<
                T0: Fn(*const dyn Fn(u64) -> Popper<bool>, u64) -> F0,
                F0: core::future::Future<Output = bool>,
            > Rec<T0, F0>
        {
            fn new(f: T0) -> Self {
                Self {
                    f,
                    arg: None,
                    popper_inner_stack: Vec::new(),
                    future_stack: Vec::new(),
                }
            }
            fn me(self: core::pin::Pin<&Self>) -> Box<dyn Fn(u64) -> Popper<bool>> {
                let this_arg_ptr = &self.arg as *const _;
                let this_popper_inner_stack_ptr = &self.popper_inner_stack as *const _;
                Box::new(move |v: u64| {
                    let this_arg = unsafe { &mut *(this_arg_ptr as *mut Option<u64>) };
                    let this_popper_inner_stack = unsafe {
                        &mut *(this_popper_inner_stack_ptr as *mut Vec<PopperInner<bool>>)
                    };
                    debug_assert!(this_arg.is_none());
                    this_arg.replace(v);
                    this_popper_inner_stack.push(PopperInner::<bool> {
                        waker: None,
                        res: None,
                    });
                    Popper::<bool> {
                        // SAFETY: pointer to el of Vec may be moved, but pointer to Vec for getting last element can be treated as pinned.
                        inner: this_popper_inner_stack as *const _,
                    }
                })
            }
            fn call(
                self: core::pin::Pin<&Self>,
                me: Box<dyn Fn(u64) -> Popper<bool>>,
                v: u64,
            ) -> bool {
                let me_ptr = &*me as *const _;
                let root_future = (self.f)(me_ptr, v);

                let this_future_stack = unsafe {
                    &mut *(&self.future_stack as *const _ as *mut Vec<core::pin::Pin<Box<F0>>>)
                };
                this_future_stack.push(Box::pin(root_future));

                let this_arg = unsafe { &mut *(&self.arg as *const _ as *mut Option<u64>) };
                let this_popper_inner_stack = unsafe {
                    &mut *(&self.popper_inner_stack as *const _ as *mut Vec<PopperInner<bool>>)
                };

                while let Some(top) = this_future_stack.last() {
                    // We know it'll be always waken in next loop.
                    let waker = new_waker_do_nothing();
                    let cx = &mut core::task::Context::from_waker(&waker);

                    if let core::task::Poll::Ready(r) =
                        unsafe { &mut *(top as *const _ as *mut core::pin::Pin<Box<F0>>) }
                            .as_mut()
                            .poll(cx)
                    {
                        debug_assert_eq!(
                            this_future_stack.len(),
                            this_popper_inner_stack.len() + 1
                        );

                        // pop unchecked
                        unsafe {
                            debug_assert!(!this_future_stack.is_empty());
                            this_future_stack.set_len(this_future_stack.len() - 1);
                            drop(core::ptr::read(
                                this_future_stack.as_ptr().add(this_future_stack.len()),
                            ));
                        }

                        if this_popper_inner_stack.is_empty() {
                            debug_assert!(this_future_stack.is_empty());
                            return r;
                        } else {
                            let top = unsafe {
                                &mut *(this_popper_inner_stack.last().unwrap_unchecked() as *const _
                                    as *mut PopperInner<bool>)
                            };

                            top.res = Some(r);
                            // It's waking Rust generated Futures (F0).
                            debug_assert!(top.waker.is_some());
                            unsafe { top.waker.take().unwrap_unchecked() }.wake();
                        }

                        continue;
                    }
                    debug_assert_eq!(this_future_stack.len(), this_popper_inner_stack.len());
                    let arg = unsafe { this_arg.take().unwrap_unchecked() };

                    let future = (self.f)(me_ptr, arg);
                    this_future_stack.push(Box::pin(future));
                }
                unsafe {
                    if cfg!(debug_assertions) {
                        unreachable!();
                    }
                    core::hint::unreachable_unchecked()
                }
            }
        }
        Rec::new(|is_odd, v: u64| async move {
            let is_odd = |v| unsafe { (*is_odd)(v) };
            if v == 0 {
                false
            } else {
                is_odd(v - 1).await
            }
        })
    };
    let is_odd = unsafe { core::pin::Pin::new_unchecked(&is_odd) };
    let is_odd = (is_odd,);
    let is_odd = (is_odd.0, |v: u64| is_odd.0.call(is_odd.0.me(), v));
    let is_odd = is_odd.1;

    println!("is_odd(1000000)={}", is_odd(1000000));
}

以下要点。

  • waker_do_nothing_vtable はwakeを要求されても何もしないwakerです。nullポインタで、いかなる要求に対しても基本的に何もしません。同期される状況下では次のループでwakeされることは保証されるので。
  • is_odd(1000000)の計算でデバッグビルド400ms、リリースビルド100msぐらいです。
    • 1000000という数は、ローカルでスタックを展開させない場合にスタックオーバーフローを起こす数でもあります。
  • 無駄な最後の詰め替えは相互再帰を想定しているものです。
  • Popper は Pin に強制的にしてますが、実際は Pin はされていません。ですが、これはあくまでも中間的な Future であり、かつその中身は自己の中にPinでないといけないステートを持たないので問題ありません。逆にinnerの、Vecの指す先はPinされていないといけないです。
  • 基本的には一瞬で全部作って、関数を出ないまま(同期的に)全部消費しきるから問題ない、という具合です。
  • F0には真のPinが必要です。 Vec<core::pin::Pin<Box<F0>>>Vec<F0> でやろうとすると移動してしまいますが、これは await から帰ってきた後に位置が変わってる可能性があるため、(末尾再帰でない限り) Pin されていなければいけません。
  • なのでコストかけて Box::pin しています。効率はアロケータ頼りです。
末尾再帰でなくする例 (is_oddという名前は完全無視)
            if v < 2 {
                false
            } else {
                !(is_odd(v - 1).await != is_odd(v - 2).await)
            }

相互再起への課題

このアプローチのままでは相互再起が微妙にできません。
まず、相互の呼び出し時にコールスタックを積んでは意味がないですから、 .call 相当ではなく、 me の受け渡しが必要です。しかし、これらが呼ばれた場合、それぞれの持ってるスタックに対してpollしなければいけません。回避するためには、スタックを全て共通にしなければなりませんが、dyn経由になるのでその分のオーバーヘッドは増えます。

上記にも書いたように、末尾再帰なら最適化できる、などのオプションは色々あるので追求したらキリがないわけですが、なんとなく単一再起は分岐したいと思ってしまったので、一旦これでマクロを実装してみようという次第です。

lumaluma

相互再帰なしだけど一旦完成した

相互再帰は最悪単一の再帰に(適当にフラグ用の変数使えば)帰着できるので、一旦これで。

  • 単一の関数での再帰
  • メモ化の戦略の指定
  • メモかしないことも選択可能
  • コールスタックのOSレベルのスタックからヒープ上のスタックへの展開
  • 宣言的マクロ ( macro_rules! ) のみ利用
  • 相互再帰(を簡単に記述できる手法の提供)

この現状でまずは記事を書く。

編集: 書いたのが以下。

https://zenn.dev/luma/articles/rust-auto-call-stack-in-heap