Closed6

Vec::dedup_by の最適化

yukiyuki

10%くらい速度向上した模様。ptr::copy_nonoverlapping ではなく、ptr::copy を使用するようにした修正らしい。

Vec::extend のときと同様に、drop する際に最適化をかける形式をとっているように見える。最終的には mem::forget する感じかな。あとで詳細はまとめる。

先日話題になっていた Rust のパターン集(の案)でいうところの panic guard とかに該当しそうな感じ。
https://gist.github.com/qnighy/be99c2ece6f3f4b1248608a04e104b38

yukiyuki

修正前は、Slicepartition_dedup_by というメソッドと、その結果取得できる新しいベクタの要素数を保持しておき、最後にその長さに truncate するというロジックをとっていた。truncate は中で不要になった領域を drop するだけなので、partition_dedup_by に少し問題があったと考えるのがよさそう。

yukiyuki

全然関係ないのだけど、dedup ってたしか重複排除のことだったよな…?と思って、記事を漁っていたらこんな記事を見つけた。

https://qiita.com/yagince/items/73184237964e9dbb8b3d

余計なことをせずに全部 clone したほうが結果的には速いという調査結果まで載っており、こういうどのタイミングで clone すると遅くなるかって難しいよな〜、ちゃんとベンチを地道に取るしかないよね〜、と思うなど。

yukiyuki

partition_dedup_by のアルゴリズムについては、下記コメントに記載されている。

        // Although we have a mutable reference to `self`, we cannot make
        // *arbitrary* changes. The `same_bucket` calls could panic, so we
        // must ensure that the slice is in a valid state at all times.
        //
        // The way that we handle this is by using swaps; we iterate
        // over all the elements, swapping as we go so that at the end
        // the elements we wish to keep are in the front, and those we
        // wish to reject are at the back. We can then split the slice.
        // This operation is still `O(n)`.
        //
        // Example: We start in this state, where `r` represents "next
        // read" and `w` represents "next_write`.
        //
        //           r
        //     +---+---+---+---+---+---+
        //     | 0 | 1 | 1 | 2 | 3 | 3 |
        //     +---+---+---+---+---+---+
        //           w
        //
        // Comparing self[r] against self[w-1], this is not a duplicate, so
        // we swap self[r] and self[w] (no effect as r==w) and then increment both
        // r and w, leaving us with:
        //
        //               r
        //     +---+---+---+---+---+---+
        //     | 0 | 1 | 1 | 2 | 3 | 3 |
        //     +---+---+---+---+---+---+
        //               w
        //
        // Comparing self[r] against self[w-1], this value is a duplicate,
        // so we increment `r` but leave everything else unchanged:
        //
        //                   r
        //     +---+---+---+---+---+---+
        //     | 0 | 1 | 1 | 2 | 3 | 3 |
        //     +---+---+---+---+---+---+
        //               w
        //
        // Comparing self[r] against self[w-1], this is not a duplicate,
        // so swap self[r] and self[w] and advance r and w:
        //
        //                       r
        //     +---+---+---+---+---+---+
        //     | 0 | 1 | 2 | 1 | 3 | 3 |
        //     +---+---+---+---+---+---+
        //                   w
        //
        // Not a duplicate, repeat:
        //
        //                           r
        //     +---+---+---+---+---+---+
        //     | 0 | 1 | 2 | 3 | 1 | 3 |
        //     +---+---+---+---+---+---+
        //                       w
        //
        // Duplicate, advance r. End of slice. Split at w.

これがすべてで、要するに進んだ r (読み込みのインデックスみたいなもの) が重複検知をし、重複があった場合には逐一その重複した値たちをスライスの後ろにずらしていく。それで、最後に w が残った場所で split を行って、重複分をカットした新しいスライスを作るイメージかなと思われる。

該当するコードは下記。mem::swap が呼び出されていて、例のごとくちょっと不吉な匂いがするので、mem::swap を読んでいこうと思う。ちなみに PR のコメントには copy_nonoverlapping が問題っぽいという話が書いてあったのがヒントだと思う。

        unsafe {
            // Avoid bounds checks by using raw pointers.
            while next_read < len {
                let ptr_read = ptr.add(next_read);
                let prev_ptr_write = ptr.add(next_write - 1);
                if !same_bucket(&mut *ptr_read, &mut *prev_ptr_write) {
                    if next_read != next_write {
                        let ptr_write = prev_ptr_write.offset(1);
                        mem::swap(&mut *ptr_read, &mut *ptr_write);
                    }
                    next_write += 1;
                }
                next_read += 1;
            }
        }
yukiyuki

mem::swap の中では下記関数を呼び出している。この中で copy_nonoverlapping が1回呼び出しされるケースがある。T 型のサイズが32より小さかった場合には確実に呼び出される。

#[inline]
pub(crate) unsafe fn swap_nonoverlapping_one<T>(x: *mut T, y: *mut T) {
    // For types smaller than the block optimization below,
    // just swap directly to avoid pessimizing codegen.
    if mem::size_of::<T>() < 32 {
        // SAFETY: the caller must guarantee that `x` and `y` are valid
        // for writes, properly aligned, and non-overlapping.
        unsafe {
            let z = read(x);
            copy_nonoverlapping(y, x, 1);
            write(y, z);
        }
    } else {
        // SAFETY: the caller must uphold the safety contract for `swap_nonoverlapping`.
        unsafe { swap_nonoverlapping(x, y, 1) };
    }
}

そうでなかったケースでは swap_nonoverlapping_one という関数が呼び出されるが、

#[inline]
pub(crate) unsafe fn swap_nonoverlapping_one<T>(x: *mut T, y: *mut T) {
    // For types smaller than the block optimization below,
    // just swap directly to avoid pessimizing codegen.
    if mem::size_of::<T>() < 32 {
        // SAFETY: the caller must guarantee that `x` and `y` are valid
        // for writes, properly aligned, and non-overlapping.
        unsafe {
            let z = read(x);
            copy_nonoverlapping(y, x, 1);
            write(y, z);
        }
    } else {
        // SAFETY: the caller must uphold the safety contract for `swap_nonoverlapping`.
        unsafe { swap_nonoverlapping(x, y, 1) };
    }
}

これはさらに swap_nonoverlapping という関数が呼び出しされる。

#[inline]
#[stable(feature = "swap_nonoverlapping", since = "1.27.0")]
pub unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
    if cfg!(debug_assertions)
        && !(is_aligned_and_not_null(x)
            && is_aligned_and_not_null(y)
            && is_nonoverlapping(x, y, count))
    {
        // Not panicking to keep codegen impact smaller.
        abort();
    }

    let x = x as *mut u8;
    let y = y as *mut u8;
    let len = mem::size_of::<T>() * count;
    // SAFETY: the caller must guarantee that `x` and `y` are
    // valid for writes and properly aligned.
    unsafe { swap_nonoverlapping_bytes(x, y, len) }
}

さらに、swap_nonoverlapping_bytes という関数が呼ばれる。中で一度バイト形式に変換してこの関数を読んでいる。_bytes では、copy_nonoverlapping が合計で3回呼び出しされる。

このスクラップは2022/07/02にクローズされました