繰り返し二乗法 - 高速な累乗計算 -
はじめに
大学の授業で扱った関数型プログラミングの授業で教授に紹介された指数関数の高速な計算方法が面白いなと思ったので競技プログラミング用に実装しました. それを共有していこうと思います. 使用する言語はC++です.
問題設定
自然数
について, n, p を998244353で割った余りを求めなさい. n^p
愚直にやっていたら
そこで新たな方法を紹介する.
計算時間は
アルゴリズム
先にイメージだけ紹介しておく. 簡単のためにmodをとらないで行う.
つまり
という状態で計算を行う. あとはbitごとに指数を見て行って1になっている時にかけていくと完成する.
計算を実際に追っていく. 答えを
int res = 1;
用意しておく.
まず
res = res * 3; // resは3になる.
次に
(説明を読み終わったあとなら
最後に
res = res * 81; // resは243になる.
さてこれの何が良いかというと, 実はresにかけ合わせる値を用意するのが簡単ということがある.
かける値は(かけない可能性があるのはおいておいて)
と3の2の冪乗になるので
となっていって前の数を二乗すると次の数を得ることができる.
実装(modなし)
一旦modは無視して実装したものがこちら↓↓
int mod_pow(int n, int p, int mod){
int res = 1;
int pw = n;
while(p > 0){
if(p & 1){
res = res * pw;
}
pw = pw * pw;
p = p >> 1;
}
return res;
}
関数の中身を1要素ずつ解釈していく.
int mod_pow(int n, int p, int mod){
}
nをp乗した値を整数型で返す関数である.
int res = 1;
int pw = n;
結果を格納する変数としてresを初期値を1としている. 乗法の単位元だからである.
ビットが立っている時に掛け合わせていくpwもnで初期化する.
while(p > 0){
}
指数部が0より大きい限り繰り返す. 10進数を2進数に変換する手順を追っているような感覚である.
if(p & 1){
res = res * pw;
}
p & 1ではpの最下位のビットが1になっている時にのみ1(つまりtrue)を返すのでこの時には二進数で書いた時に
そして
pw = pw * pw;
でpwを二乗する. (
p = p >> 1;
pを右ビットシフトすることでif文のところで1と&演算をするビットを1つ(左に)ずらしている.
p = p / 2;
とやっていることは同じ.
return res;
で最後にresを返す.
いまいち分からなかった人向けの具体例
先程の例で
int mod_pow(int n(3), int p(5), int mod){
で
int res = 1;
int pw = n(3);
1周目のwhileループでは
while(p(5) > 0){
if(p(5) & 1){
// 5は二進数で101なので1との&演算で1(true)が返る.
res(3) = res(1) * pw(3);
resは 1 * 3 = 3となる.
}
pw(9) = pw(3) * pw(3);
// pw = 9になる.
p = p >> 1;
// p = 2になる.
}
2周目のwhileループでは
while(p(2) > 0){
if(p(2) & 1){
// 2は二進数で10なので一番下のbitが0だからfalse.
res(27) = res(3) * pw(9);
// これは起こらないのでres = 3のまま.
}
pw(81) = pw(9) * pw(9);
// pw = 81になる.
p = p >> 1;
// p = 0になる.
}
3周目のwhileループでは
while(p(1) > 0){
if(p(1) & 1){
// 1は二進数で1なので一番下のbitが1だからtrue.
res(243) = res(3) * pw(81);
// res = 243になる.
}
pw(6561) = pw(81) * pw(81);
// pw = 6561になる.
p = p >> 1;
// p = 0になる.
}
で次はもうp == 0だからwhileの中に入れないからこれで終わり.
実装(完成版)
これにあとはmodの要素を付け足せば終わりである.
int mod_pow(int n, int p, int mod){
n = n % mod;
if(n < 0){
n = n + mod;
}
int res = 1;
int pw=n;
while(p > 0){
if(p & 1){
res = res * pw % mod;
}
pw = pw * pw % mod;
p = p >> 1;
}
return res;
}
whileの中身の変更は剰余をとっているだけなので
n = n % mod;
if(n < 0){
n = n + mod;
}
の部分だけ解説する. 最後に余りを求めるのだからnを最初からmodで割った余りで計算しても良い. 詳しくは高校数学のmodの性質なのでネットで調べてもらうと良いだろう. そして(nが負なことをあまり想定してはいないが)もしnが負ならn % modの値が負で返ってきてしまうので,もう一度足すことで
負の数の剰余
言語の性質に依るからなんとも言えないがc++では
-5 % 3は実行すると, -2と帰ってくる.
これによって超たまに偶数の判定で
if(k % 2 == 1)
とするとkが奇数でも
-1 == 1
の判定でfalseが返ってこのif文に入れない時もある.
だから偶奇の判定は
if(k % 2 == 0){
}
でやることを推奨する.
あと競プロ用のマクロで私が愛用しているものがあるのだが
#define MOD(a, b) (((a) % (b)) + (b)) % (b)
こうすることでさっきの理論で1度正にしてからもう1度剰余を取るので必ず正の余りを得ることができる.
その他
"余りを取らない"かつ"
また、
計算が高速な上に、誤差が出やすいdouble型でのpow演算を避けることができるので整数型の指数や余りの議論を同時にしたいときには積極的に採用していきたい.
例題
最後に
おそらく標準ライブラリでも整数型なら同じ実装になっていると思います。 それ以外の時は気をつけて場合分けをしながら
わかりにくかったり、間違っている部分があったらぜひ教えてください!
Discussion