Vec::retain の最適化
この最適化がおもしろかったので、あとで読んでメモを取る。理解できるかがちょっと謎だけど。
おおまかな変更としては単純で、
-
swap
の廃止。 -
truncate
の廃止。
で性能を上げたという感じになっていそう。
ベンチマークを見る感じ、いいケースだと倍以上の性能が出るようになっていそう。
元の実装
元の実装は下記のようになっている。
#[stable(feature = "rust1", since = "1.0.0")]
pub fn retain<F>(&mut self, mut f: F)
where
F: FnMut(&T) -> bool,
{
let len = self.len();
let mut del = 0;
{
let v = &mut **self;
for i in 0..len {
if !f(&v[i]) {
del += 1;
} else if del > 0 {
v.swap(i - del, i);
}
}
}
if del > 0 {
self.truncate(len - del);
}
}
たとえば、[1,2,3,4] という配列に対して、偶数を残す retain を行った場合、retain 後の配列は [2, 4] となる。これは次のようにして算出されることがわかる。
インデックスと値の対応表
0 | 1 | 2 | 3 |
---|---|---|---|
1 | 2 | 3 | 4 |
- i = 0 でループ
- 1 は 1 % 2 == 0 を満たさないので、del + 1 される。この時点で del = 1
- 配列自体に変化なし。
- i = 1 でループ
- 2 は 2 % 2 == 0 を満たすので、次の del > 0 の判定に移る。
- del = 1 より、これは満たされる。
- この時点で、i - del = 1 - 1 = 0 番目と、i = 1 番目の要素が swap される。 (後ほど説明するが、ここで
copy
が1回、copy_nonoverlapping
が2回呼び出される) - 配列は [2, 1, 3, 4] となる。
- i = 2 でループ(配列 = [2, 1, 3, 4])
- 3 は 3 % 2 == 0 を満たさないので、del + 1 される。この時点で del = 1 + 1 = 2
- 配列自体に変化なし。
- i = 3 でループ(配列 = [2, 1, 3, 4])
- 4 は 4 % 2 == 0 を満たすので、次の del > 0 の判定に移る。
- del = 2 より、これは満たされる。
- この時点で、i - del = 3 - 2 = 1 番目と、i = 3 番目の要素が swap される。 (後ほど説明するが、ここで
copy
が1回、copy_nonoverlapping
が2回呼び出される) - 配列は [2, 4, 3, 1] となる。
実質 copy
は2回行われていることになる。copy_nonoverlapping
は4回行われていることになる。
このループが終了した後、
- del = 2, len = 4 より、truncate(2) が求まる。左から2個の要素を残し、それ以外を truncate する。
- つまり、後ろ2つがカットされる。
- 最終的に取り出される配列は [2, 4] となる。
という操作が行われ、新しい配列が完成する。
修正後の方針
修正後の方針として特徴的なのは、インデックスによる探索&swapping (copy が都度走る)から、そもそも先に要素ごとのポインタを取っておいて、retain 対象でない要素が来たらすぐに drop しておき、要素を詰めた配列を move (memcpy
)によって用意しておいて、最後の最後で1回だけ copy (memmove
)を走らせるという構成に変わっている点だった。こうすると copy が呼び出される回数が減って、大幅な速度上昇につながった、という感じかと思う。
以降のスクラップでは、つらつらと調査したメモを書いておく。
Rust の Vec
はポインタを用いた処理がふんだんに利用されている。もちろんそうした処理は本質的には unsafe だが、unsafe な箇所には必ずなぜ安全を保証できているかが補足されている。そうしたコメントを読むだけでも安全性の保証の仕方の勉強になる。加えて、そもそも筆者のようにあまりポインタ演算に慣れていないプログラマにとっては、こうした処理を写経するだけで勉強になる。unsafe 周りの絶好の教材だと言える。
PR を見ると、swap
を廃止してメモリへの書き込みを減らすと書いてある。
まず swap 関数についてだが、Vec::swap
の場合は、インデックス a
と b
にある要素同士を入れ替えるというものになっている。コメントのコードをそのまま Playground で動かした例はこちら。
この例の場合、a, b, c, d の並びのうち、1番目と3番目を入れ替えている。元の配列の1番目はb
、3番目はd
となる。これらを入れ替えるから、swap 後の配列は a, d, c, b となる。
/// Swaps two elements in the slice.
///
/// # Arguments
///
/// * a - The index of the first element
/// * b - The index of the second element
///
/// # Panics
///
/// Panics if `a` or `b` are out of bounds.
///
/// # Examples
///
/// ```
/// let mut v = ["a", "b", "c", "d"];
/// v.swap(1, 3);
/// assert!(v == ["a", "d", "c", "b"]);
/// ```
実装を見てみる。Vec::swap
を見てみると、下記のような実装になっている。
#[stable(feature = "rust1", since = "1.0.0")]
#[inline]
pub fn swap(&mut self, a: usize, b: usize) {
// Can't take two mutable loans from one vector, so instead just cast
// them to their raw pointers to do the swap.
let pa: *mut T = &mut self[a];
let pb: *mut T = &mut self[b];
// SAFETY: `pa` and `pb` have been created from safe mutable references and refer
// to elements in the slice and therefore are guaranteed to be valid and aligned.
// Note that accessing the elements behind `a` and `b` is checked and will
// panic when out of bounds.
unsafe {
ptr::swap(pa, pb);
}
}
手順を簡単にまとめると(ポインタは「作る」でいいのかな…「取る」かな…)、
-
a
番目の要素の可変な生ポインタを作る。 -
b
番目の要素の可変な生ポインタを作る。 -
std::ptr::swap
を呼び出す。
生ポインタを作るくだりは、std::ptr::swap
がそれを要求するから。std::ptr::swap
を見てみる。
#[inline]
#[stable(feature = "rust1", since = "1.0.0")]
pub unsafe fn swap<T>(x: *mut T, y: *mut T) {
// Give ourselves some scratch space to work with.
// We do not have to worry about drops: `MaybeUninit` does nothing when dropped.
let mut tmp = MaybeUninit::<T>::uninit();
// Perform the swap
// SAFETY: the caller must guarantee that `x` and `y` are
// valid for writes and properly aligned. `tmp` cannot be
// overlapping either `x` or `y` because `tmp` was just allocated
// on the stack as a separate allocated object.
unsafe {
copy_nonoverlapping(x, tmp.as_mut_ptr(), 1);
copy(y, x, 1); // `x` and `y` may overlap
copy_nonoverlapping(tmp.as_ptr(), y, 1);
}
}
ここは C の授業で習うような一般的なスワップっぽいものが書いてある気がした。
MaybeUninit
がわからなすぎるが、要するに「初期化されていないかもしれない」状態を示すための型のようだ。未初期化領域は単純に未定義動作になるので、たぶんそれを防いでいるのだと思う。本来ならば受け取りの T 型の初期化用関数を呼び出したいのだと思うが、それをやると、たとえば T
のトレイト境界に Default
を入れる必要が出てくるとか、余計なメモリ領域を使用するなどいろいろ面倒なのでこれを使っているのかなと思った。
さて話はそれたが、
-
tmp
をMaybeUninit
で初期化。 -
x
とtmp
に対するcopy_nonoverlapping
を呼び出し。 -
copy
でy
をx
にcopy
する。 -
tmp
とy
に対するcopy_nonoverlapping
を呼び出し。
これだけ見ると、
#include <stdio.h>
void swap(int *x, int *y) {
int tmp;
tmp = *x;
*x = *y;
*y = tmp;
}
みたいなことをやっているだけに見えるが、copy_nonoverlapping
という関数が何をしているのかが知りたい。調べてみる。
まだ確証はもてていないが、この時点で合計3回の copy に関する操作が走っていることになる。これが今回の最適化のキーポイントになっているかもしれない。
copy_nonoverlapping は実質的な memcpy。
copy は実質的な memmove。copy は重なりがある可能性がある。
重なりありの memmove は遅いパターンがあるらしい。
重なりがない場合の copy は copy_nonoverlapping と実質的に速度は変わらないケースがほとんどのようだが、重なりありの場合は、重なった分を 1bit ずつちまちまコピーを走らせるので、このケースは結構遅くなる(らしい)。C の実装内容を見てみてもそんな感じだった。ベンチマークの結果で大きな差が出た箇所は、こうした重なりの有無が影響しているのかもしれない。
次見るべきは truncate
の方で、これは指定した数、左から要素を残す関数っぽい。
truncate に 1 を入れると1、2を入れると1, 2、3を入れると1,2,3を残す。ちなみに4を入れてもパニックはしなかった。
fn main() {
let mut v = vec![1, 2, 3];
v.truncate(0);
assert_eq!(v, []);
}
最終的に利用される drop_in_place
を、想定した長さ(= Vec の len)以上に対してかけると、不要なメモリ領域も含む slice を返すことになり、メモリ安全ではないのではと一瞬思ったが、unsafe
ブロックの最初の1行目で if len > self.len
だった場合に何もせずに返す旨の実装がされている。これで安全でない処理は弾いているから問題ない。結果、3の長さの Vec に対して4を指定しても、リスト全部を返すだけだから問題ない。という実装になっているということがわかった。
/// Shortens the vector, keeping the first `len` elements and dropping
/// the rest.
///
/// If `len` is greater than the vector's current length, this has no
/// effect.
///
/// The [`drain`] method can emulate `truncate`, but causes the excess
/// elements to be returned instead of dropped.
///
/// Note that this method has no effect on the allocated capacity
/// of the vector.
///
/// # Examples
///
/// Truncating a five element vector to two elements:
///
/// ```
/// let mut vec = vec![1, 2, 3, 4, 5];
/// vec.truncate(2);
/// assert_eq!(vec, [1, 2]);
/// ```
///
/// No truncation occurs when `len` is greater than the vector's current
/// length:
///
/// ```
/// let mut vec = vec![1, 2, 3];
/// vec.truncate(8);
/// assert_eq!(vec, [1, 2, 3]);
/// ```
///
/// Truncating when `len == 0` is equivalent to calling the [`clear`]
/// method.
///
/// ```
/// let mut vec = vec![1, 2, 3];
/// vec.truncate(0);
/// assert_eq!(vec, []);
/// ```
///
/// [`clear`]: Vec::clear
/// [`drain`]: Vec::drain
#[stable(feature = "rust1", since = "1.0.0")]
pub fn truncate(&mut self, len: usize) {
// This is safe because:
//
// * the slice passed to `drop_in_place` is valid; the `len > self.len`
// case avoids creating an invalid slice, and
// * the `len` of the vector is shrunk before calling `drop_in_place`,
// such that no value will be dropped twice in case `drop_in_place`
// were to panic once (if it panics twice, the program aborts).
unsafe {
if len > self.len {
return;
}
let remaining_len = self.len - len;
let s = ptr::slice_from_raw_parts_mut(self.as_mut_ptr().add(len), remaining_len);
self.len = len;
ptr::drop_in_place(s);
}
}
drop_in_place
の挙動が、実装を読んだだけではわからなかったので、ちょっとサンプルコードをまずは動かしてみた。
use std::ptr;
use std::rc::Rc;
fn main() {
let last = Rc::new(1);
let weak = Rc::downgrade(&last);
let mut v = vec![Rc::new(0), last];
unsafe {
let ptr = &mut v[1] as *mut _;
v.set_len(1);
ptr::drop_in_place(ptr);
}
assert_eq!(v, &[0.into()]);
assert!(weak.upgrade().is_none());
}
図にしたいが、
- [0, 1] の配列を作る。(値は
Rc
になっている)1 については弱参照を作っておく。 -
ptr
に v[1] = 1 を指す可変なポインタを格納しておく。 - ベクタの len を 1 にする。
-
ptr
を destruct する。 - そうすると、1 は解放された状態になるので、実質ベクタに残るのは v[0] = 0 のみとなる。
- v[1] に対する弱参照も消ているというチェックもしている。
要するに引数で投げ込んだポインタの指す領域の解放を行っているという感じっぽい。
弱参照をはじめとする Rc の基本的な概念や操作はこちらの記事が詳しい。
新実装の方も、基本的に copy_nonoverlapping
を使用していたり、drop_in_place
を使用していたりする点は基本的に変わりがない。一方で、swap を使用しない形にロジックを修正しているので、その部分の実装がだいぶ変わっている。
#[stable(feature = "rust1", since = "1.0.0")]
pub fn retain<F>(&mut self, mut f: F)
where
F: FnMut(&T) -> bool,
{
let original_len = self.len();
// Avoid double drop if the drop guard is not executed,
// since we may make some holes during the process.
unsafe { self.set_len(0) };
// Vec: [Kept, Kept, Hole, Hole, Hole, Hole, Unchecked, Unchecked]
// |<- processed len ->| ^- next to check
// |<- deleted cnt ->|
// |<- original_len ->|
// Kept: Elements which predicate returns true on.
// Hole: Moved or dropped element slot.
// Unchecked: Unchecked valid elements.
//
// This drop guard will be invoked when predicate or `drop` of element panicked.
// It shifts unchecked elements to cover holes and `set_len` to the correct length.
// In cases when predicate and `drop` never panick, it will be optimized out.
struct BackshiftOnDrop<'a, T, A: Allocator> {
v: &'a mut Vec<T, A>,
processed_len: usize,
deleted_cnt: usize,
original_len: usize,
}
impl<T, A: Allocator> Drop for BackshiftOnDrop<'_, T, A> {
fn drop(&mut self) {
if self.deleted_cnt > 0 {
// SAFETY: Trailing unchecked items must be valid since we never touch them.
unsafe {
ptr::copy(
self.v.as_ptr().add(self.processed_len),
self.v.as_mut_ptr().add(self.processed_len - self.deleted_cnt),
self.original_len - self.processed_len,
);
}
}
// SAFETY: After filling holes, all items are in contiguous memory.
unsafe {
self.v.set_len(self.original_len - self.deleted_cnt);
}
}
}
let mut g = BackshiftOnDrop { v: self, processed_len: 0, deleted_cnt: 0, original_len };
while g.processed_len < original_len {
// SAFETY: Unchecked element must be valid.
let cur = unsafe { &mut *g.v.as_mut_ptr().add(g.processed_len) };
if !f(cur) {
// Advance early to avoid double drop if `drop_in_place` panicked.
g.processed_len += 1;
g.deleted_cnt += 1;
// SAFETY: We never touch this element again after dropped.
unsafe { ptr::drop_in_place(cur) };
// We already advanced the counter.
continue;
}
if g.deleted_cnt > 0 {
// SAFETY: `deleted_cnt` > 0, so the hole slot must not overlap with current element.
// We use copy for move, and never touch this element again.
unsafe {
let hole_slot = g.v.as_mut_ptr().add(g.processed_len - g.deleted_cnt);
ptr::copy_nonoverlapping(cur, hole_slot, 1);
}
}
g.processed_len += 1;
}
// All item are processed. This can be optimized to `set_len` by LLVM.
drop(g);
}
大雑把なアルゴリズムは、
-
BackshiftOnDrop
を作る。 - 1で作った構造体の
processed_len
がoriginal_len
(元の配列の長さ) をこえるまでは、ループ処理を回し続ける。- retain の条件に一致しない場合は、その要素を drop しておく。
- 要素を削除したカウンタが0より大きければ、削除分を反映した配列の状態を move しておく。
- 1を drop する。drop 時に後処理として、下記2つが走る。
- 要素を削除したカウンタが0より大きければ、処理した分を copy する。
- ベクタ自身の持つサイズを現状のものに調整する。
元の実装と比べると、move の呼び出される回数はそこまで大差ないかもしれないが、copy は最後の drop 時の1回だけ行われるようになるので、まず間違いなく回数が減る。
let mut g = BackshiftOnDrop {
v: self_,
processed_len: 0,
deleted_cnt: 0,
original_len,
};
while g.processed_len < original_len {
// SAFETY: Unchecked element must be valid.
let cur = unsafe { &mut *g.v.as_mut_ptr().add(g.processed_len) };
if !f(cur) {
// Advance early to avoid double drop if `drop_in_place` panicked.
g.processed_len += 1;
g.deleted_cnt += 1;
// SAFETY: We never touch this element again after dropped.
unsafe { ptr::drop_in_place(cur) };
// We already advanced the counter.
continue;
}
if g.deleted_cnt > 0 {
// SAFETY: `deleted_cnt` > 0, so the hole slot must not overlap with current element.
// We use copy for move, and never touch this element again.
unsafe {
let hole_slot = g.v.as_mut_ptr().add(g.processed_len - g.deleted_cnt);
ptr::copy_nonoverlapping(cur, hole_slot, 1);
}
}
g.processed_len += 1;
}
先ほどと同じように [1, 2, 3, 4] の配列を偶数だけ取り出すという操作を行う。最終的な成果物は[2,4]になっているはず。
新しいアルゴリズムのキーワードになるのは
- processed_len
- deleted_cnt
- original_len
で、それらを埋めながら見ていく。
あとはソースコードのコメントにもあるとおり、
Kept: 今回 retain される対象の要素。
Hole: move あるいは drop された要素。
Unchecked: まだチェックしていない正当な要素。
これらの用語はあとで登場する予定。
初期時点では、original_len は配列の大きさを最初に入れるので、まず4になる。processed_lenと deleted_cnt はそれぞれ初期化で0が入れられる。
ちょっとポインタの操作が複雑なので、要素の開始位置のポインタをまずは決めてみることにする。各要素のサイズは、i32
だから 4bit ずつかな?
address | 0x7ffe4d8f54a0 | 0x7ffe4d8f54a4 | 0x7ffe4d8f54a8 | 0x7ffe4d8f54ac |
---|---|---|---|---|
index | 0 | 1 | 2 | 3 |
value | 1i32 | 2i32 | 3i32 | 4i32 |
- 最初の
g
を作る。processed_len = 0, deleted_cnt = 0, original_len = 4 - 1回目の while (processed_len = 0 < original_len = 4)
-
cur
を作る。cur
=g.v.as_mut_ptr().add(0)
=0x7ffe4d8f54a0
。 -
cur
の示す先の値は1i32
なので、最初の if ブロックの条件を満たす。processed_len = 1, deleted_cnt = 1 となる。また、cur
は drop される。配列は[(Hole), 2 (Unchecked), 3 (Unchecked), 4 (Unchecked)]
になっているはず。 - 次のループに飛ぶ。
-
- 2回目の while (processed_len = 1 < original_len = 4, deleted_cnt = 1)
-
cur
を作る。cur
=g.v.as_mut_ptr().add(1)
=0x7ffe4d8f54a4
。 -
cur
の示す先の値は2i32
なので、最初の if ブロックの条件は満たさない。次。 - deleted_cnt = 1 より、条件を満たす。
-
hole_slot
はg.v.as_mut_ptr().add(1 - 1)
=g.v.as_mut_ptr().add(0)
=0x7ffe4d8f54a0
(さっき drop したところ)。 -
cur
をhole_slot
に move する。つまり、[2 (Kept), (Hole), 3 (Unchecked), 4 (Unchecked)] となっているはず。
-
- processed_len = 2
-
- 3回目の while (processed_len = 2 < original_len = 4, deleted_cnt = 1)
-
cur
を作る。cur
=g.v.as_mut_ptr().add(2)
=0x7ffe4d8f54a8
。 -
cur
の示す先の値は3i32
なので、最初の if ブロックの条件を満たす。processed_len = 3, deleted_cnt = 2 となる。また、cur
は drop される。配列は[2 (Kept), (Hole), (Hole), 4 (Unchecked)]
になっているはず。 - 次のループに飛ぶ。
-
- 4回目の while (processed_len = 3 < original_len = 4, deleted_cnt = 2)
-
cur
を作る。cur
=g.v.as_mut_ptr().add(3)
=0x7ffe4d8f54ac
。 -
cur
の示す先の値は4i32
なので、最初の if ブロックの条件は満たさない。次。 - deleted_cnt = 2 より、条件を満たす。
-
hole_slot
はg.v.as_mut_ptr().add(3 - 2)
=g.v.as_mut_ptr().add(1)
=0x7ffe4d8f54a4
-
cur
をhole_slot
に move する。つまり、[2 (Kept), 4 (Kept), (Hole), (Hole)] となっているはず。
-
- processed_len = 4
-
- 5回目のループは条件を満たさずできない。
- drop 処理を行う。
- processed_len = 4, deletec_cnt = 2
- 終わった時点でのメモリの状態は、0x7ffe4d8f54a0 = 2, 0x7ffe4d8f54a4 =4, 0x7ffe4d8f54a8 = Hole, 0x7ffe4d8f54ac = Hole
- copy(
g.v.as_ptr().add(4) = 0x7ffe4d8f54b0
,g.v.as_mut_ptr().add(2) = 0x7ffe4d8f54a8
, 0) が実行される。count に 0 が入っているので、実質何もしない。 - Vec の len は2がセットされる。[2, 4] が len に含まれることになった。
copy
は drop 処理の中で1度行われるだけに変わっている。ここが、条件によっては大きなパフォーマンス向上の原動力になったことがわかる。
copy_nonoverlapping
も2回になっていて減ってる。
copy の挙動がよくわからないので、実質同じような状況を作って確かめてみる。
copy の定義としては
pub unsafe fn copy<T>(src: *const T, dst: *mut T, count: usize)
src から dst に向けて、count * size_of::<T> byte 分のコピーを走らせる。つまりたとえば、count に1を入れて、T
が i32
だった場合は、1*4bytes = 32bit 分、src からコピーをし、dst に入れる。0だった場合は0byteのコピーなのだから、何もしないという行為に等しいはず。
実験したソースコードは下記。
下記のような形にして、第3引数に0を入れておく。
fn main() {
unsafe {
let mut v = vec![1, 2, 3, 4];
let cur = v.as_ptr();
let cur1 = v.as_ptr().add(0);
let cur2 = v.as_mut_ptr().add(1);
let cur3 = v.as_mut_ptr().add(2);
let cur4 = v.as_ptr().add(3);
let cur5 = v.as_ptr().add(4);
println!("{:?} = {}", cur1, *cur1);
println!("{:?} = {}", cur2, *cur2);
println!("{:?} = {}", cur3, *cur3);
println!("{:?} = {}", cur4, *cur4);
println!("{:?} = {}", cur5, *cur5);
std::ptr::copy(cur5, cur3, 0);
println!("source = {:?}", *cur5);
println!("dest = {:?}", *cur3);
}
}
すると、実行した結果は下記のようになった。
0x5607ad21fad0 = 1
0x5607ad21fad4 = 2
0x5607ad21fad8 = 3
0x5607ad21fadc = 4
0x5607ad21fae0 = 0
source = 0
dest = 3
memcpy や memmmove を使ったことがないのでよくわからないが、どうやら 0 を入れると何もしないという挙動をする感じがする。たとえば1を入れると、コピー処理(memmove)が走ることが見て取れる。source が dest にコピーされた様子が見て取れる(両方とも0という値になっている)。
fn main() {
unsafe {
let mut v = vec![1, 2, 3, 4];
let cur = v.as_ptr();
let cur1 = v.as_ptr().add(0);
let cur2 = v.as_mut_ptr().add(1);
let cur3 = v.as_mut_ptr().add(2);
let cur4 = v.as_ptr().add(3);
let cur5 = v.as_ptr().add(4);
println!("{:?} = {}", cur1, *cur1);
println!("{:?} = {}", cur2, *cur2);
println!("{:?} = {}", cur3, *cur3);
println!("{:?} = {}", cur4, *cur4);
println!("{:?} = {}", cur5, *cur5);
std::ptr::copy(cur5, cur3, 1); // 1 * size_of::<i32> = 1 * 4 bytes = 32 bit
println!("source = {:?}", *cur5);
println!("dest = {:?}", *cur3);
}
}
0x55820f77cad0 = 1
0x55820f77cad4 = 2
0x55820f77cad8 = 3
0x55820f77cadc = 4
0x55820f77cae0 = 0
source = 0
dest = 0
今回用意したケースだと、そもそも単純すぎて前後であまり旨味が変わらない感じになってしまったが、もう少し複雑な例だと旨味がよくわかるかもしれない。そもそも retain の条件に合致した要素数に比例して copy (memmove) が発生していた部分が、単純に1回で済むように変わっているので、要素数が長ければ長いほど効果が出てきそう。また、メモリ領域のダブりがあると、memmove は遅くなるらしいということもわかった。こうした2つのケースでは、劇的に性能が改善している可能性があると思った。ベンチマークの結果もそのようになっていそう。
最終的にこちらの記事にまとめました。