😇

美しきBarrett乗算よ(詠嘆)

に公開
7

はじめに

Barrett乗算実装してみて、その美しさに気付いたので共有します。とくに、以下の2点が顕著です。

  • ゼロ除算を除く任意のu64 % u32を高速に計算できる。当然、u32 * u32 % u32も計算できる。
  • Lemireの高速剰余アルゴリズムとマジックナンバーを共有している。

他にも優れた点はたくさんあります。

注意事項

  • 多倍長の演算を避けるため、64bitマシンで32bitの乗算を考えます。
  • 1での除算には追加の処理が必要なので除外します。対応の一例も紹介します。

Barret乗算のここがすごい

オーバーヘッドが小さい

struct Barrett {
    // 64bitで持っておくと便利
    modulus: u64,
    magic: u64,
}

impl Barrett {
    fn new(modulus: u32) -> Self {
        // 後述のように、1除算は特別な対応が必要。必要ないので切り捨てる。
        assert!(modulus > 1);

        let modulus = modulus as u64;
        // 詳細は後述。1除算の場合のみ、オーバーフローして0になる。
        let magic = (u64::MAX / modulus).wrapping_add(1);

        Self { modulus, magic }
    }
}

Barrett乗算はマジックナンバーを1つしか必要としません。Montgomery乗算とPlantard乗算は2つ必要です[1]。さらに、マジックナンバーの計算は除算と加算をそれぞれ1回だけです。2回使えば損失はほとんどゼロです。

時間的にも空間的にもオーバーヘッドは僅かです。

あらゆる32bit剰余乗算ができる

実装例の前にアルゴリズムの概要を説明します。Barrett乗算のマジックナンバーは、法をM \in [1, 2^{32})として、次のように定義されます。

R = \left\lceil \frac{2^{64}}{M} \right\rceil = \left\lfloor \frac{2^{64} - 1}{M} \right\rfloor + 1

任意の整数A \in [0, 2^{64})について、次が成り立ちます。

\left\lfloor \frac{A}{M} \right\rfloor \le \frac{A R}{2^{64}} < \left\lfloor \frac{A}{M} \right\rfloor + 2

つまり、\lfloor A / M \rfloorを高々1の誤差で得られます。

証明

定義より、次式が成り立ちます。

\begin{align*} A &= B M + C \, (0 \le B, 0 \le C \lt M)\\ R M &= 2^{64} + D \, (0 \le D \lt M) \end{align*}

したがって、

\begin{align*} A R &= (B M + C) R = B (2^{64} + D) + C R\\ B D + C R &\le (B + R)(M - 1)\\ &= (A - C) + (2^{64} + D) - (B + R)\\ &= (A + 2^{64}) + D - (B + C + R) \end{align*}

明らかに、A R \ge B 2^{64}が成り立ちます。また、A < 2^{64}かつD < M < Rより、A R < (B + 2) 2^{64}です。

以上より、Barrett乗算は次のように書けます。

impl Barrett {
    fn residue64(&self, x: u64) -> u64 {
        // `magic + x + 0`の上位64bit。widening_mul()はまだnightly。
        let quot = self.magic.carrying_mul(x, 0);
        // `quot`は32bitなので、オーバーフローしない。
        // アンダーフローしたら、`quot`を1だけ大きく見積もっていたことになる。
        let (residue, b) = x.overflowing_sub(quot * self.modulus);

        if b { residue.wrapping_add(self.modulus) } else { residue }
    }

    fn mul(&self, a: u32, b: u32) -> u32 {
        self.residue64(a as u64 * b as u64) as u32
    }
}
1除算の対応

1除算の場合、R = 2^{64}がオーバーフローしてゼロになってしまいます。Compiler Explorerに実装例があります。

整数をそのまま使える

Montgomery乗算やPlantard乗算では、剰余と一対一対応した別の値を内部的に利用します。一方、Barrett乗算ではそのような工夫は必要ありません。上記の事実と合わせて次のような性質を導けます。

  • 特別な初期化処理が要らない。
  • 自由に法を変えることができる。

Lemireの高速剰余アルゴリズムをゼロコストで使える

Barrett乗算は商を近似することで剰余を計算します。これには正規化のコストがかかります。Lemireの高速剰余アルゴリズムでは、商を介することなく剰余を直接求めます。おもしろいことに、マジックナンバーは同じです。

64bitマシンではu32 % u32を乗算2回で計算できます。さらに、乗算1回でu32 % u32 == 0を計算できます。

impl Barrett {
    fn residue32(&self, x: u32) -> u32 {
        // `residue64()`では商を近似するために、`hi`を利用していた。
        // 剰余を計算するには`lo`を使えば良い。
        let lo = self.magic.wrapping_mul(x as u64);
        // `lo * modulus + 0`の上位64bit
        lo.carrying_mul(self.modulus, 0).1 as u32
    }

    fn can_divide(&self, x: u32) -> bool {
        let lo = self.magic.wrapping_mul(x as u64);
        // 1除算では`magic`がオーバーフローしているので、
        // `lo <= self.magic.wrapping_sub(1)`とする。
        lo < self.magic
    }
}

Barret乗算のここがちょっと……

作業メモリが大きい

Barret乗算乗算はu32 * u32 % u32の計算をするために、96bitのスペースを利用します。そのため、32bit環境では多倍長の計算をせざるを得ません。一方、Montgomery乗算は64bitのスペースで良いため、mul_high()を利用して効率的な実装が可能です。64bit環境でu64 * u64 % u64をするときも同じです。

次のような使い分けが可能です。なお、数値は64bit環境での値です。

手法 特徴
Barrett乗算 [2, 2^{32}) 偶数を法にとれる
Plantard乗算 \lesssim 2^{31.3}の奇数 速い
Montgomery乗算 u64の奇数 大きな法をとれる

乗法逆元の計算に除算が必要

Montgomery乗算やPlantard乗算では法が奇数なので、2の乗法逆元が存在します。このため、拡張Binary GCDアルゴリズムを高速化できます[2]。Barrett乗算では偶数の法が許されるので、このような高速化はできません。拡張Binary GCDは比較的遅く、拡張Euclid互除法の方が速い傾向にあります。

まとめ

作業メモリの大きさが致命的ですが、Barrett乗算は美しいアルゴリズムです。Lemireの高速除算アルゴリズムとのマジックナンバーの一致に気付いたときは、感動してしまいました。

最後に少しだけ一般論を扱います。u{P} % u{Q}の除算で、マジックナンバーを\lceil 2^{R} / M\rceilとします。このとき、商を高々1の誤差で求めるにはR = \max(P, 2Q)、剰余を厳密に求めるにはR = P + Qが必要最小限です。P \ge Qという(自然な)仮定をおくと、剰余のマジックナンバーが商のそれよりも大きくなります。とくに、Q = Pのときは1.5倍になります。ところで、商を誤差なしで求めるにもR = P + Qが必要です[3]商の誤差は修正が効くので精度を妥協してもよいという訳です。

脚注
  1. Plantard乗算で入力を規格化するためにLemireのアルゴリズムを利用するなら追加で1つ必要。 ↩︎

  2. Euclid互除法に対するBinary GCDと同程度の高速化が期待されます。 ↩︎

  3. コメント欄を参照してください。 ↩︎

Discussion

Mizar/みざーMizar/みざー

Lemireの高速剰余アルゴリズムですが、被除数を32bit とする範囲では商を求めるための正規化は必要ないと思います。恐らくこのように言い換えられるでしょうか。

  • 整数 k,d,n について、 1 \le k,\quad 1 \le d < 2^k,\quad 0 \le n < 2^k とする。
  • B := 2^{2k},\quad c := \lceil B / d \rceil,\quad \mathrm{low}(n) := (cn) \bmod B とおく。
  • Euclid除法により n = qd + r,\quad q,r \in \mathbb{Z},\quad 0 \le r < d とおく。

このとき

\left\lfloor \frac{cn}{B} \right\rfloor = q = \left\lfloor \frac{n}{d} \right\rfloor, \qquad \left\lfloor \frac{\mathrm{low}(n)d}{B} \right\rfloor = r = n \bmod d

が成り立つ。

この主張は、 k=32,\ B=2^{64} のとき次の u32 用実装に対応する。

/// returns (n / D, n % D)
#[target_feature(enable="bmi2")]
const fn divmod_u32<const D: u32>(n: u32) -> (u32, u32) {
    assert!(D > 0);
    if D == 1 {
        // ceil(2^64 / D) は D == 1 のときオーバーフローするため、特例で処理
        return (n, 0);
    } else {
        // c = ceil(2^64 / D) を 64 bit で計算する
        let c = u64::MAX / (D as u64) + 1;
        // t = cn を 128 bit で計算する
        let t = (c as u128) * (n as u128);
        // t の上位 64 bit が商 q に等しい
        let q = (t >> 64) as u32;
        // t の下位 64 bit が low(n) に等しい
        let low = t as u64;
        // r = floor(low(n) * D / 2^64) を計算する
        let r = (((low as u128) * (D as u128)) >> u64::BITS) as u32;
        // q と r を返す
        (q, r)
    }
}

// examples

#[unsafe(no_mangle)]
#[target_feature(enable="bmi2")]
pub const fn divmod_7u32(n: u32) -> (u32, u32) {
    divmod_u32::<7>(n)
}

#[unsafe(no_mangle)]
#[target_feature(enable="bmi2")]
pub const fn div_7u32(n: u32) -> u32 {
    divmod_7u32(n).0
}

#[unsafe(no_mangle)]
#[target_feature(enable="bmi2")]
pub const fn mod_7u32(n: u32) -> u32 {
    divmod_7u32(n).1
}

https://rust.godbolt.org/z/jaKnqqod4

div_7u32:
        mov     edx, edi
        movabs  rax, 2635249153387078803
        mulx    rax, rax, rax
        ret

divmod_7u32:
        mov     edx, edi
        movabs  rax, 2635249153387078803
        mulx    rax, rdx, rax
        mov     ecx, 7
        mulx    rdx, rdx, rcx
        ret

mod_7u32:
        mov     eax, edi
        movabs  rdx, 2635249153387078803
        imul    rdx, rax
        mov     eax, 7
        mulx    rax, rax, rax
        ret
qdot3qdot3

コメントありがとうございます。Barrett乗算の性質を活かした良い設計ですね。商を求めると剰余がタダになるとは。 剰余を求めると商がタダでした。


投稿されたコードに1か所だけバグがあります。商を誤差なしで求める条件は次の通りです。

2^{R - Q} \ge 2^P + 2^Q

「まとめ」より

P = Q = 32のとき、R = P + Q + 1 = 65で十分です。ここで、マジックナンバーについてM \ge 3 \Leftrightarrow \bar{M} = \lceil 2^{65}/M \rceil< 2^{64}が成り立ちます。したがって、M = 1, 2の場合を特別扱いすれば、ワードサイズにやさしくなります。

頂いたリンクを改変しナイーブな除算を追加しました。コンパイラに勝てるようです。


Lemireの方法で剰余は誤差なしで求め、Barrettの方法で商を高々1の誤差で求めるのが記事の結論です。商の誤差は剰余を求める際にアンダーフローしたかどうかで識別できますが、剰余の誤差を求めるのは商を求めてみる必要があります。つまり、剰余の誤差はコスト面で許容できません。

qdot3qdot3

異なる手法でも近いコードが出せます。

Mizar/みざーMizar/みざー

除数が 2 の場合の反例はどのようなものがあるでしょうか。おそらく、 2^{R-Q} \ge 2^P + 2^Q は商や余りを正しく計算できるかどうか示す、鋭い条件ではないと思われます。

余りに関してはまだ整理をしていませんが、商に関しては以下のようなコードで K < ⌊⌈β/D⌉*x/β⌋ - ⌊x/D⌋ を満たす最小の非負整数 x を求めることができるかと思います。

def xmin_a1_upper(d: int, b: int, k: int) -> int | None:
    """
    上側最小違反点を返す。
    d は除数 D、b は近似除数 β、k は誤差の閾値 K を表す。
    すなわち
        min { x ∈ ℕ | K < ⌊⌈β/D⌉*x/β⌋ - ⌊x/D⌋ }
    を返す。存在しなければ None。
    """
    assert d > 0 and b > 0 and k >= 0
    m = (b + d - 1) // d
    delta = m * d - b
    if delta == 0:
        return None
    q = (k * b + m - 1) // delta
    r = ((k + 1) * b - q * delta + m - 1) // m
    return q * d + r
qdot3qdot3

正確に評価しました。

\begin{align*} B D + C \bar{M} &= \frac{A - C}{M} (M \bar{M} - 2^R) + C \bar{M} \\ &= (A - \cancel{C}) \bar{M} - \frac{(A - C) 2^R}{M} + \cancel{C \bar{M}} \\ &= \frac{A}{M} (M \bar{M} - 2^R) + \frac{C}{M} 2^R \\ &= \frac{D}{M} A + \frac{C}{M} 2^R \end{align*}

これが2^Rよりも小さければよいので、

2^R > \frac{D}{M - C} A \le D A < 2^{P + Q}

ご指摘通り、R = P + Qで良いです。

Mizar/みざーMizar/みざー

もう少し分かりやすそうな命題も出しておきます。

命題

正の整数 D,\beta に対し

M^+ := \left\lceil \frac{\beta}{D}\right\rceil

とおく。整数 x

0\le x<M^+

を満たすとき、次の等式が成り立つ。

\left\lfloor\frac{M^+x}{\beta}\right\rfloor=\left\lfloor\frac{x}{D}\right\rfloor

証明

M^+=\left\lceil \frac{\beta}{D}\right\rceil

より、天井関数の基本性質から

0 \le M^+ - \frac{\beta}{D} < 1

である。

ここで x

0\le x<M^+

を満たすとする。x は整数なので

x \le M^+-1

であり、また

M^+-1 < \frac{\beta}{D}

であるから

x < \frac{\beta}{D}

を得る。したがって

\frac{x}{\beta}<\frac{1}{D}

である。

さて

q:=\left\lfloor \frac{x}{D}\right\rfloor

とおくと、ある整数 r が存在して

x=qD+r,\qquad 0\le r<D

と書ける。よって

\frac{x}{D}=q+\frac{r}{D}

である。

一方、

\frac{M^+x}{\beta}-\frac{x}{D}=\frac{x}{\beta}\left(M^+-\frac{\beta}{D}\right)

なので、上で得た不等式より

0\le\frac{M^+x}{\beta}-\frac{x}{D}<\frac{x}{\beta}<\frac{1}{D}.

したがって

\frac{M^+x}{\beta}<\frac{x}{D}+\frac{1}{D}=q+\frac{r}{D}+\frac{1}{D}\le q+1,

また

\frac{M^+x}{\beta}\ge \frac{x}{D}=q+\frac{r}{D}\ge q.

ゆえに

q \le \frac{M^+x}{\beta} < q+1.

したがって床関数をとれば

\left\lfloor \frac{M^+x}{\beta}\right\rfloor=q=\left\lfloor \frac{x}{D}\right\rfloor.

以上より

0\le x<M^+\Longrightarrow\left\lfloor\frac{M^+x}{\beta}\right\rfloor=\left\lfloor\frac{x}{D}\right\rfloor

が示された。

qdot3qdot3

1除算対応v2

1除算でpanicはfootgunなので。

  • M = 1ならmask = 0M > 1ならmask = !0とする。最後にmaskをかける。パイプライン処理ならOK。
  • Result<Self, E>を返し、Eは0除算と1除算をバリアントにもつ。入力を殺す必要があるかもしれないが、重たい処理をスキップできるのはうれしい。