📝

[C++]メモ化再帰を実装してみる

2022/10/31に公開

先日のABC275にて、メモ化再帰というものを学びました。

ワード自体は聞き覚えがあったのですが、コードレベルで知識があったかというとそうでもなかったので、今回C++で実装してみた次第です。

対象

C++でコーディングを行うため、読める方を対象としています。
また、再帰関数の知識を用います。

再帰関数の理解には、AtCoderにて掲載されているAPG4bをおすすめします。
「ふわっとした知識はあるけど、いまいち言語化はできないよ~」という場合は、同記事の「再帰関数の性質」から目を通してみるとよいです🐇

https://atcoder.jp/contests/APG4b/tasks/APG4b_v

メモ化とは

プログラムの高速化のための最適化技法の一種であり、サブルーチン呼び出しの結果を後で再利用するために保持し、そのサブルーチン(関数)の呼び出し毎の再計算を防ぐ手法である。

https://ja.wikipedia.org/wiki/メモ化

やさしい言葉に置き換えると、「AならばB、CならばD、...」と続くとき、次にAやCが呼ばれたときにわざわざ計算しなくていいように、配列などのコンテナに引数と結果の組を保存しておくことといった感じです。
同じ処理を省くことで、処理速度が大幅に上がることもあります。

実装してみる

せっかくなので、先日行われたABC275の問題を引用してみます。

ABC275-D. Yet Another Recursive Function

https://atcoder.jp/contests/abc275/tasks/abc275_d

f(n) = f([n/2]) + f([n/3])というルールで数列を作ったとき、前からn番目にある数値を求める問題のようです。
ここで、[k]という記号は「kの小数点以下を切り捨てたもの」という意味になります(詳しくはガウス記号で検索してみてください)。

まずは再帰関数のみで

inline int in_int() {int x; cin >> x; return x;} // 数値の標準入力
using ull = unsigned long long; // 型命名を略しています

ull f(ull n) {
  // ベースケース
  if(n == 0) return 1;

  // 再帰ステップ
  return f(n/2) + f(n/3);
}

void Main() {
  ull n = in_int();
  cout << f(n) << endl;
}
  1. 与えられた式f(n) = f([n/2]) + f([n/3])を再帰ステップに設定
  2. ↑の式は必ずn=0に落ち着くので、n=0をベースケースとして設定

D問題にしては意外と単純...と思っていましたが、そう甘くはありません。

よく問題を見てみると、入力される数値Nの範囲は0 \leqq N \leqq 10^{18}、かなり大きい数をとることがわかります。
実際に上のコードにN=10^{18}を計算させると、処理時間[1]は以下の通り。

42804ms、つまり約42秒の処理時間を要しています。
問題の制限時間は2秒であるため、余裕で間に合いません。

ここで、メモ化を利用してみます。

メモ化を使って実装してみる

inline int in_int() {int x; cin >> x; return x;} // 数値の標準入力
using ull = unsigned long long; // 型命名を略しています

+ map<ull, ull> memo;

ull f(ull n) {
  // ベースケース
  if(n == 0) return 1;

  // 再帰ステップ
-   return f(n/2) + f(n/3);
+   if(memo[n] != 0) return memo[n]; // メモがあったら返す
+   ull ans = f(n/2) + f(n/3);
+   memo[n] = ans;
+   return ans;
}

void Main() {
  ull n = in_int();
  cout << f(n) << endl;
}

メモ用の変数としてはmap型を採用しています。
キーとバリューの一組で扱う、いわゆる辞書のような型ですね。

注意すべき点としては、メモに値が存在する時の処理です。

if(memo[n] != 0) return memo[n]; // メモがあったら返す

このmemo[n] != 0という処理ですが、当然memo[n]が存在しない場合があります。
その際はエラーになるのか?...というと、そうでもありません。

map型の[]演算子の優秀な点として、存在しないキーが選択されたら自動的にデフォルト値を代入するという特徴を持ちます。
この場合のデフォルト値は(0, 0)となるため、バリューが0 = 値が存在しないということを示していることがわかります。[2]

実際にN=10^{18}を与えて実行してみると...

309ms、わずか0.3秒で処理を終了させています。すさまじい。

おわりに

大雑把ではありますが、メモ化再帰の解説、ならびに問題の解説でした。
まだ理解の浅い部分もありますので、よければマサカリお待ちしております。

脚注
  1. 処理時間の計測にはclock()関数を用いています。 ↩︎

  2. この場合は「バリューのとりうる値が1以上」ということがわかっているためこのような処理をしています。問題によってはバリューが0以下の値もとることがあるため、十分に注意してください。 ↩︎

Discussion