🕳️

トレイト境界の落とし穴

2022/02/21に公開

ジェネリック関数を書く前に

単純な関数でなければジェネリック関数を書く前に一旦具体的な型で書いて、必要なトレイト境界を洗い出してからジェネリック化することをお勧めします。
何故かというと関数呼び出しの深いところで追加のトレイト境界が必要になるとそれを呼び出す側の関数すべてにトレイト境界を追加する必要が出てくる場合があるためです。

HRTBが必要になるケース

例えば最大公約数(GCD)を求める関数を作りたくなったとします。
最大公約数は以下のように計算できます。

\gcd(x, y) = \begin{cases} y & (x = 0)\\ x & (y = 0)\\ \gcd(y, x \bmod y) & (\text{otherwize}) \end{cases}

これを素直にRustのコードにすると以下のようになります。

use num::Zero;
use std::ops::Rem;

fn gcd<T>(x: T, y: T) -> T
where
    T: Copy + Rem<Output = T> + Zero,
{
    if x.is_zero() {
        y
    } else if y.is_zero() {
        x
    } else {
        gcd(y, x % y)
    }
}

fn main() {
    assert_eq!(gcd(57, 42), 3);
}

このコードでは Copy トレイトを要求しているのでプリミティブ整数型では動きますが、num::BigInt などでは動きません。num::BigInt で動くようにするには Copy の代わりに Clone を要求して以下のようにすることが考えられます。

fn gcd<T>(x: T, y: T) -> T
where
    T: Clone + Rem<Output = T> + Zero,
{
    if x.is_zero() {
        y
    } else if y.is_zero() {
        x
    } else {
        gcd(y.clone(), x % y)
    }
}

このようにすれば num::BigInt でも問題なく動きます。しかし、再帰回数と同じ回数のnum::BigInt の clone が必要になります。これを防ぐために x % y&x % &y に変えることを考えます。(こうすれば%xyの所有権を奪わなくなるため) そのためにはトレイト境界を T: Rem<Output = T> から &T: Rem<Output = T> に変える必要がありそうです。しかし、これでは以下のようなエラーになります。

error[E0637]: `&` without an explicit lifetime name cannot be used here
 --> src/main.rs:7:5
  |
7 |     &T: Rem<Output = T>,
  |     ^ explicit lifetime name needed here

さて、ライフタイムが必要だと言われてしましましたがどうすれば良いでしょう? (以前の私はここで途方に暮れましたが)

答えは高階トレイト境界(HRTB: Higher-Rank Trait Bounds)を使うです。以下がHRTBを使った書き方です。

fn gcd<T>(x: T, y: T) -> T
where
    T: Zero,
    for<'x> &'x T: Rem<Output = T>,
{
    if x.is_zero() {
        y
    } else if y.is_zero() {
        x
    } else {
        let r = &x % &y;
        gcd(y, r)
    }
}

for<'x> &'x T: Rem<Output = T> は任意のライフタイム 'x について&'x T: Rem<Output = T>を満たすというトレイト境界です。(for<'x>\forall xと同様)

もしHRTBを使わない解決方法があったら教えて下さい。

謎のE0275

以下のようなref_addを定義したとします。このコードは何の問題も無く動きます。

use std::ops::Add;

fn ref_add<'a, T>(a: &'a T, b: &'a T) -> T
where
    &'a T: Add<Output = T>,
{
    a + b
}

fn main() {
    let a = 2i32;
    let b = 3i32;
    assert_eq!(ref_add(&a, &b), 5i32);
}

上のコードに適当な構造体の定義とトレイトの実装を加えて以下のようにします。

use std::ops::Add;

struct S<T>(T);

impl<'a, T> Add for &'a S<T>
where
    for <'x> &'x T: Add<Output = T>
{
    type Output = S<T>;
    fn add(self, rhs: Self) -> Self::Output {
        S(&self.0 + &rhs.0)
    }
}

fn ref_add<'a, T>(a: &'a T, b: &'a T) -> T
where
    &'a T: Add<Output = T>,
{
    a + b
}

fn main() {
    let a = 2i32;
    let b = 3i32;
    assert_eq!(ref_add(&a, &b), 5i32);
}

すると何故か以下のようなエラーになります。(きっとあなたの予想外の結果でしょう)

error[E0275]: overflow evaluating the requirement `for<'x> &'x std::simd::Simd<_, {_: usize}>: std::ops::Add`
  --> src/main.rs:25:16
   |
25 |     assert_eq!(ref_add(&a, &b), 5i32);
   |                ^^^^^^^
   |
   = help: consider increasing the recursion limit by adding a `#![recursion_limit = "256"]` attribute to your crate (`temp`)
note: required because of the requirements on the impl of `for<'x> std::ops::Add` for `&'x S<std::simd::Simd<_, {_: usize}>>`
  --> src/main.rs:5:13
   |
5  | impl<'a, T> Add for &'a S<T>
   |             ^^^     ^^^^^^^^
   = note: 127 redundant requirements hidden
   = note: required because of the requirements on the impl of `std::ops::Add` for `&S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<std::simd::Simd<_, {_: usize}>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>`
note: required by a bound in `ref_add`
  --> src/main.rs:17:12
   |
15 | fn ref_add<'a, T>(a: &'a T, b: &'a T) -> T
   |    ------- required by a bound in this
16 | where
17 |     &'a T: Add<Output = T>,
   |            ^^^^^^^^^^^^^^^ required by this bound in `ref_add`

For more information about this error, try `rustc --explain E0275`.

このエラーを防ぐには ref_add の呼び出し時に型を明示すれば良いです。つまり ref_add(&a, &b)ref_add::<i32>(&a, &b) とすれば問題なくコンパイルが通ります。

謎のE0277

追記: 1.61.0-nightly(1eb7258 2022-03-08) で修正済み。

fn ref_add<T>(a: &T, b: &T) -> T
where
    T: for<'x> From<<&'x T as Add>::Output>,
    for<'x> &'x T: Add,
{
    T::from(a + b)
}

fn main() {
    let a = 2i32;
    let b = 3i32;
    assert_eq!(ref_add(&a, &b), 5i32);
}

上のコードは下のようなエラーになります。

error[E0277]: the trait bound `for<'x> i32: From<<&'x i32 as Add>::Output>` is not satisfied
  --> src/main.rs:14:16
   |
14 |     assert_eq!(ref_add(&a, &b), 5i32);
   |                ^^^^^^^ the trait `for<'x> From<<&'x i32 as Add>::Output>` is not implemented for `i32`
   |
note: required by a bound in `ref_add`
  --> src/main.rs:5:8
   |
3  | fn ref_add<T>(a: &T, b: &T) -> T
   |    ------- required by a bound in this
4  | where
5  |     T: for<'x> From<<&'x T as Add>::Output>,
   |        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `ref_add`
help: consider introducing a `where` bound, but there might be an alternative better way to express this requirement
   |
11 | fn main() where i32: for<'x> From<<&'x i32 as Add>::Output> {
   |           +++++++++++++++++++++++++++++++++++++++++++++++++

For more information about this error, try `rustc --explain E0277`.

これは以下のようにHRTBを使わないで実装するか、呼び出し時に型を明示すれば問題なくコンパイルが通ります。

fn ref_add<'a, T>(a: &'a T, b: &'a T) -> T
where
    T: From<<&'a T as Add>::Output>,
    &'a T: Add,
{
    T::from(a + b)
}

ちなみにこの問題については https://github.com/rust-lang/rust/issues/94160 で報告済みです。

Discussion