Vec::dedup_by の最適化
10%くらい速度向上した模様。ptr::copy_nonoverlapping
ではなく、ptr::copy
を使用するようにした修正らしい。
Vec::extend
のときと同様に、drop する際に最適化をかける形式をとっているように見える。最終的には mem::forget する感じかな。あとで詳細はまとめる。
先日話題になっていた Rust のパターン集(の案)でいうところの panic guard とかに該当しそうな感じ。
修正前は、Slice
の partition_dedup_by
というメソッドと、その結果取得できる新しいベクタの要素数を保持しておき、最後にその長さに truncate
するというロジックをとっていた。truncate
は中で不要になった領域を drop
するだけなので、partition_dedup_by
に少し問題があったと考えるのがよさそう。
全然関係ないのだけど、dedup
ってたしか重複排除のことだったよな…?と思って、記事を漁っていたらこんな記事を見つけた。
余計なことをせずに全部 clone したほうが結果的には速いという調査結果まで載っており、こういうどのタイミングで clone すると遅くなるかって難しいよな〜、ちゃんとベンチを地道に取るしかないよね〜、と思うなど。
partition_dedup_by
のアルゴリズムについては、下記コメントに記載されている。
// Although we have a mutable reference to `self`, we cannot make
// *arbitrary* changes. The `same_bucket` calls could panic, so we
// must ensure that the slice is in a valid state at all times.
//
// The way that we handle this is by using swaps; we iterate
// over all the elements, swapping as we go so that at the end
// the elements we wish to keep are in the front, and those we
// wish to reject are at the back. We can then split the slice.
// This operation is still `O(n)`.
//
// Example: We start in this state, where `r` represents "next
// read" and `w` represents "next_write`.
//
// r
// +---+---+---+---+---+---+
// | 0 | 1 | 1 | 2 | 3 | 3 |
// +---+---+---+---+---+---+
// w
//
// Comparing self[r] against self[w-1], this is not a duplicate, so
// we swap self[r] and self[w] (no effect as r==w) and then increment both
// r and w, leaving us with:
//
// r
// +---+---+---+---+---+---+
// | 0 | 1 | 1 | 2 | 3 | 3 |
// +---+---+---+---+---+---+
// w
//
// Comparing self[r] against self[w-1], this value is a duplicate,
// so we increment `r` but leave everything else unchanged:
//
// r
// +---+---+---+---+---+---+
// | 0 | 1 | 1 | 2 | 3 | 3 |
// +---+---+---+---+---+---+
// w
//
// Comparing self[r] against self[w-1], this is not a duplicate,
// so swap self[r] and self[w] and advance r and w:
//
// r
// +---+---+---+---+---+---+
// | 0 | 1 | 2 | 1 | 3 | 3 |
// +---+---+---+---+---+---+
// w
//
// Not a duplicate, repeat:
//
// r
// +---+---+---+---+---+---+
// | 0 | 1 | 2 | 3 | 1 | 3 |
// +---+---+---+---+---+---+
// w
//
// Duplicate, advance r. End of slice. Split at w.
これがすべてで、要するに進んだ r
(読み込みのインデックスみたいなもの) が重複検知をし、重複があった場合には逐一その重複した値たちをスライスの後ろにずらしていく。それで、最後に w
が残った場所で split を行って、重複分をカットした新しいスライスを作るイメージかなと思われる。
該当するコードは下記。mem::swap
が呼び出されていて、例のごとくちょっと不吉な匂いがするので、mem::swap
を読んでいこうと思う。ちなみに PR のコメントには copy_nonoverlapping
が問題っぽいという話が書いてあったのがヒントだと思う。
unsafe {
// Avoid bounds checks by using raw pointers.
while next_read < len {
let ptr_read = ptr.add(next_read);
let prev_ptr_write = ptr.add(next_write - 1);
if !same_bucket(&mut *ptr_read, &mut *prev_ptr_write) {
if next_read != next_write {
let ptr_write = prev_ptr_write.offset(1);
mem::swap(&mut *ptr_read, &mut *ptr_write);
}
next_write += 1;
}
next_read += 1;
}
}
mem::swap
の中では下記関数を呼び出している。この中で copy_nonoverlapping が1回呼び出しされるケースがある。T
型のサイズが32より小さかった場合には確実に呼び出される。
#[inline]
pub(crate) unsafe fn swap_nonoverlapping_one<T>(x: *mut T, y: *mut T) {
// For types smaller than the block optimization below,
// just swap directly to avoid pessimizing codegen.
if mem::size_of::<T>() < 32 {
// SAFETY: the caller must guarantee that `x` and `y` are valid
// for writes, properly aligned, and non-overlapping.
unsafe {
let z = read(x);
copy_nonoverlapping(y, x, 1);
write(y, z);
}
} else {
// SAFETY: the caller must uphold the safety contract for `swap_nonoverlapping`.
unsafe { swap_nonoverlapping(x, y, 1) };
}
}
そうでなかったケースでは swap_nonoverlapping_one
という関数が呼び出されるが、
#[inline]
pub(crate) unsafe fn swap_nonoverlapping_one<T>(x: *mut T, y: *mut T) {
// For types smaller than the block optimization below,
// just swap directly to avoid pessimizing codegen.
if mem::size_of::<T>() < 32 {
// SAFETY: the caller must guarantee that `x` and `y` are valid
// for writes, properly aligned, and non-overlapping.
unsafe {
let z = read(x);
copy_nonoverlapping(y, x, 1);
write(y, z);
}
} else {
// SAFETY: the caller must uphold the safety contract for `swap_nonoverlapping`.
unsafe { swap_nonoverlapping(x, y, 1) };
}
}
これはさらに swap_nonoverlapping
という関数が呼び出しされる。
#[inline]
#[stable(feature = "swap_nonoverlapping", since = "1.27.0")]
pub unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
if cfg!(debug_assertions)
&& !(is_aligned_and_not_null(x)
&& is_aligned_and_not_null(y)
&& is_nonoverlapping(x, y, count))
{
// Not panicking to keep codegen impact smaller.
abort();
}
let x = x as *mut u8;
let y = y as *mut u8;
let len = mem::size_of::<T>() * count;
// SAFETY: the caller must guarantee that `x` and `y` are
// valid for writes and properly aligned.
unsafe { swap_nonoverlapping_bytes(x, y, len) }
}
さらに、swap_nonoverlapping_bytes
という関数が呼ばれる。中で一度バイト形式に変換してこの関数を読んでいる。_bytes
では、copy_nonoverlapping
が合計で3回呼び出しされる。