📝

繰り返し二乗法 - 高速な累乗計算 -

に公開

はじめに

大学の授業で扱った関数型プログラミングの授業で教授に紹介された指数関数の高速な計算方法が面白いなと思ったので競技プログラミング用に実装しました. それを共有していこうと思います. 使用する言語はC++です.

問題設定

自然数n, pについて, n^pを998244353で割った余りを求めなさい.

愚直にやっていたらpが大きい時にオーバーフローしてしまったり, p=10^{10}とかになった時にはそもそも実行時間が間に合わないかもしれない.
そこで新たな方法を紹介する.
計算時間はO(log_2 p)である. 個人的には衝撃的に早くて驚いた.

アルゴリズム

先にイメージだけ紹介しておく. 簡単のためにmodをとらないで行う.
3^5を計算するにあたり

5_{(10)}=101_{(2)}
という進数変換を行う.
つまり

3^{5_{(10)}}=3^{101_{(2)}}

という状態で計算を行う. あとはbitごとに指数を見て行って1になっている時にかけていくと完成する.
計算を実際に追っていく. 答えを

int res = 1;

用意しておく.
まず2^0(=1)のビットが1なので, res3^1をかける.

res = res * 3;    // resは3になる. 

次に2^1(=2)のビットは0なので何もしない.
(説明を読み終わったあとなら2^0(=1)をかけて何も起こらなかったと思えるだろう.)
最後に2^2(=4)のビットが1なので, res3^4をかける.

res = res * 81;    // resは243になる. 

さてこれの何が良いかというと, 実はresにかけ合わせる値を用意するのが簡単ということがある.
かける値は(かけない可能性があるのはおいておいて)
3^1, 3^2, 3^4, 3^8...
と3の2の冪乗になるので
3, 9, 81, 6561
となっていって前の数を二乗すると次の数を得ることができる.

実装(modなし)

一旦modは無視して実装したものがこちら↓↓

int mod_pow(int n, int p, int mod){
    int res = 1;
    int pw = n;
    while(p > 0){
        if(p & 1){
            res = res * pw;
        }
        pw = pw * pw;
        p = p >> 1;
    }
    return res;
}

関数の中身を1要素ずつ解釈していく.

int mod_pow(int n, int p, int mod){
}

nをp乗した値を整数型で返す関数である.

int res = 1;
int pw = n;

結果を格納する変数としてresを初期値を1としている. 乗法の単位元だからである.
ビットが立っている時に掛け合わせていくpwもnで初期化する.

    while(p > 0){
    }

指数部が0より大きい限り繰り返す. 10進数を2進数に変換する手順を追っているような感覚である.

       if(p & 1){
            res = res * pw;
        }

p & 1ではpの最下位のビットが1になっている時にのみ1(つまりtrue)を返すのでこの時には二進数で書いた時にiビット目が1であり, n^{2^i}をかける必要がありこれがpwに当たる.
そして

        pw = pw * pw;

でpwを二乗する. (n^{2^i}*n^{2^i}=n^{2^i+2^i}=n^{2^{i+1}}というイメージ)

        p = p >> 1;

pを右ビットシフトすることでif文のところで1と&演算をするビットを1つ(左に)ずらしている.

        p = p / 2;

とやっていることは同じ.

    return res;

で最後にresを返す.

\underline{今の理論はnが負でも破綻しないのでこの関数はn<0にも適用することができる}.

いまいち分からなかった人向けの具体例

先程の例で3^5を計算する.

int mod_pow(int n(3), int p(5), int mod){

n,pが入ってきて,

    int res = 1;
    int pw = n(3);

1周目のwhileループでは

    while(p(5) > 0){
        if(p(5) & 1){
        // 5は二進数で101なので1との&演算で1(true)が返る.
            res(3) = res(1) * pw(3);
            resは 1 * 3 = 3となる. 
        }
        pw(9) = pw(3) * pw(3);
        // pw = 9になる. 
        p = p >> 1;
        // p = 2になる. 
    }

2周目のwhileループでは

    while(p(2) > 0){
        if(p(2) & 1){
        // 2は二進数で10なので一番下のbitが0だからfalse.
            res(27) = res(3) * pw(9);
            // これは起こらないのでres = 3のまま. 
        }
        pw(81) = pw(9) * pw(9);
        // pw = 81になる. 
        p = p >> 1;
        // p = 0になる. 
    }

3周目のwhileループでは

    while(p(1) > 0){
        if(p(1) & 1){
        // 1は二進数で1なので一番下のbitが1だからtrue.
            res(243) = res(3) * pw(81);
            // res = 243になる. 
        }
        pw(6561) = pw(81) * pw(81);
        // pw = 6561になる. 
        p = p >> 1;
        // p = 0になる. 
    }

で次はもうp == 0だからwhileの中に入れないからこれで終わり.

実装(完成版)

これにあとはmodの要素を付け足せば終わりである.

int mod_pow(int n, int p, int mod){
    n = n % mod;     
    if(n < 0){
        n = n + mod;
    }
    int res = 1;
    int pw=n;
    while(p > 0){
        if(p & 1){
            res = res * pw % mod;
        }
        pw = pw * pw % mod;
        p = p >> 1;
    }
    return res;
}

whileの中身の変更は剰余をとっているだけなので

    n = n % mod;     
    if(n < 0){
        n = n + mod;
    }

の部分だけ解説する. 最後に余りを求めるのだからnを最初からmodで割った余りで計算しても良い. 詳しくは高校数学のmodの性質なのでネットで調べてもらうと良いだろう. そして(nが負なことをあまり想定してはいないが)もしnが負ならn % modの値が負で返ってきてしまうので,もう一度足すことで0\le n<modの範囲にnを収めることができる.

負の数の剰余

言語の性質に依るからなんとも言えないがc++では
-5 % 3は実行すると, -2と帰ってくる.
これによって超たまに偶数の判定で

if(k % 2 == 1)

とするとkが奇数でも
-1 == 1
の判定でfalseが返ってこのif文に入れない時もある.
だから偶奇の判定は

if(k % 2 == 0){
}

でやることを推奨する.
あと競プロ用のマクロで私が愛用しているものがあるのだが

#define MOD(a, b) (((a) % (b)) + (b)) % (b)

こうすることでさっきの理論で1度正にしてからもう1度剰余を取るので必ず正の余りを得ることができる.

その他

"余りを取らない"かつ"pが負も許容"の時は最後に逆数を返せば良い.
また、
計算が高速な上に、誤差が出やすいdouble型でのpow演算を避けることができるので整数型の指数や余りの議論を同時にしたいときには積極的に採用していきたい.

例題

https://atcoder.jp/contests/abc178/tasks/abc178_c
https://atcoder.jp/contests/abc357/tasks/abc357_d
https://atcoder.jp/contests/abc152/tasks/abc152_e
https://leetcode.com/problems/powx-n/description/

最後に

おそらく標準ライブラリでも整数型なら同じ実装になっていると思います。 それ以外の時は気をつけて場合分けをしながらx^y = e^{y \ln(x)}で実装してると思います。 (多分)
わかりにくかったり、間違っている部分があったらぜひ教えてください!

Discussion