🦀

RustでManacher Algorithmを実装し、回文検索をO(n)で解く

commits8 min read

成果物

https://gist.github.com/ulwlu/c41509710ea99c1d3d3d1a6056961bbd

Manacher Algorithm とは

通常総当りでは O(n^3)、中心から線形探索すれば O(n^2)かかる回文検索問題を、O(n)線形時間だけで解くアルゴリズム(wiki

この動画が一番わかりやすいです。20 以上の記事や動画を見た中で一番なるほど、と感動したので記事にします。

Manacher Algorithm の方針

  1. abaxabaという文字列がある
  2. 最初から探索開始
  3. a->1, b->3, a->1, x->7
  4. 次に a を探索するのか?しません。ここがポイント。
  5. [理由] a は x を中心とする回文の範囲内。この時 x の左側と右側は、x の回文の範囲内なら全く同じなはず。
  6. なので、左側の a の 1 をコピーして終わり。
  7. 次に b は左側 3 なのでこれも 3・・・と思いきや、これはコピーしません。これが第 2 のポイント。
  8. [理由]b から長さ 3 の範囲は、x を中心とする回文の範囲の右端に到達してしまいます。これではもしabaxaba - axabaという5文字が後に続いてたら b=3 は誤りです。なので右端に到達する値を保持していたら今度はその文字から検索する必要があるので、b は探索します。
  9. 1-8 を繰り返し、右端が列終端に到達したら、以降の文字列は確実に中心点からの半径より短い = 最大長にはなりえないので修了。

何故 Manacher Algorithm は O(n)なのか

パッと見、最悪ケースだと全ての文字を探索して全ての文字について n 回探しそうだから O(n^2)では?と思われます(海外でも結構質問記事が多い参考1 , 参考 2

ポイントは

  • 探索を開始するのは、現在回文が保証されている範囲の右端を超えたときのみ
  • 回文が保証されている点は探索せずスキップできる

つまり回文が保証されている範囲、中心点、右端は必ず単調増加しており探索する index はただ右に伸びていき重複しない事がわかります。混乱の元となるのは「左側も探索しないといけない」という事実なのですが、これは中心点からlist[center-radius] == list[center+radius]という探索を行う事で計算量左端index-center = 右端index-centerです。

各比較操作が成功するたびに右に1段階前進しており、中心点も右端も絶対に縮小する事がないので探索は n 回だけしか行なっておらず、計算量は O(n)で完了します。

Manacher Algorithm の理解図

nya

thanks: draw.io

Manacher Algorithm の実装

上記のアルゴリズム理解から、

  • その index を中心とする回文の長さを保持する配列
  • 現在 focus してる回文の中心点
  • 現在 focus してる回文の右端

があれば設計できるなと気持ち覚えておきます。

1: まず偶数時をケアする

まず中央からの線形探索ができるように、偶数時のケア用のダミー文字を挿入します。
「aba」だと b から両端を見れば回文だとわかりますが
「aa」だと探索の始点がありません。なので「#a#a#」という風にダミー文字をいれて、#から aa を見つけられるようにします。

// MEMO: We need to detect odd palindrome as well,
// therefore, inserting dummy string so that
// we can find a pair with dummy center character.
let mut chars: Vec<char> = Vec::with_capacity(s.len() * 2 + 1);
for c in s.chars() {
  chars.push('#');
  chars.push(c);
}
chars.push('#');

2: 回文になっているかを探索する処理を実装する

// List: storing the length of palindrome at each index of string
let mut length_of_palindrome = vec![1usize; chars.len()];
// Integer: Current checking palindrome's center index
let mut current_center: usize = 0;
// Integer: Right edge index existing the radius away from current center
let mut right_from_current_center: usize = 0;

for i in 0..chars.len() {

//....

  // Integer: Current radius from checking index
  // If it's copied from left side and more than 1,
  // it means it's ensured so you don't need to check inside radius.
  let mut radius: usize = (length_of_palindrome[i] - 1) / 2;
  radius = radius + 1;
  // 2: Checking palindrome.
  // Need to care about overflow usize.
  while i >= radius && i + radius <= chars.len() - 1 && chars[i - radius] == chars[i + radius]
  {
    length_of_palindrome[i] = length_of_palindrome[i] + 2;
    radius = radius + 1;
  }
}

最初のlet mut radius: usize = (length_of_palindrome[i] - 1) / 2;が少しわかりづらいかも知れません。コメントにある通り、もしその地点の値が左側からコピーしてきた「3,5,10」などだったとします。
これは current_center を中心とする保証された回文の長さです。という事はその中は探索しなくてよくて、その次1個先から見れば良いことになります(この1個先が、1行次のradius = radius + 1)。
ですので効率化の手段になりえるので採用してます。

while の直後の「i >= radius」、本当は「i - radius >= 0」と後々の処理とコンテキスト合わせたいんですけど、インデックスは usize なので overflow panic を起こしてしまうためこうしてます(負になる可能性がある)。
特に制約上問題ないことが自明な条件なときに、Rust だとオーバーフローが起きる可能性のある部分をコンパイル時に Panic 出してくれるのが逆にムムムとなります(実務ではとてもありがたい神機能ですが)。
本来 overflow するような場合はその最大値にまとめてくれる saturation とか、切り捨ててくれる wrapping_add/sub などがあるのですが、usize が負になるケースはそういうメソッドでは対応できないのでこうしてます、ですが・・・・・・ちょっと見通しが悪いので、もっとスマートに書けれたらいいなと思ってます。
もし誰かご存知でしたら教えて下さい。

3: 探索しないで左側からコピーして効率化する処理を実装する

for i in 0..chars.len() {
  // 1: Check if we are looking at right side of palindrome.
  if right_from_current_center > i && i > current_center {
    // 1-1: If so copy from the left side of palindrome.
    // If the value + index exceeds the right edge index, we should cut and check palindrome later #3.
    length_of_palindrome[i] = std::cmp::min(
      right_from_current_center - i,
      length_of_palindrome[2 * current_center - i],
    );
    // 1-2: Move the checking palindrome to new index if it exceeds the right edge.
    if length_of_palindrome[i] + i >= right_from_current_center {
      current_center = i;
      right_from_current_center = length_of_palindrome[i] + i;
      // 1-3: If radius exceeds the end of list, it means checking is over.
      // You will never get the larger value because the string will get only shorter.
      if right_from_current_center >= chars.len() - 1 {
        break;
      }
    } else {
      // 1-4: If the checking index doesn't exceeds the right edge,
      // it means the length is just as same as the left side.
      // You don't need to check anymore.
      continue;
    }
  }

// ....

}

コメントにむちゃくちゃメモをしているのでほぼ説明する必要がないと思います。
もし所見のときにパット見わかりづらいところを上げるとすればright_from_current_center - iの部分でしょうか。これは現在チェックしている i に、本来左側から 10 がコピーできるとします。しかしその地点が右端だった場合、回文が保証されているのは右端から i のインデックスまでの距離、すなわち1だけです。
回文が保証されてる範囲のみにそろえて後で回文探索を開始したいため、min をかけています。

4: テスト

ここまで実装できたので試しにテストを入れてみましょう。

#[test]
fn test_longest_palindrome() {
  assert_eq!(longest_palindrome("babad".to_string()), "aba".to_string());
  assert_eq!(longest_palindrome("cbbd".to_string()), "bb".to_string());
  assert_eq!(longest_palindrome("a".to_string()), "a".to_string());
  assert_eq!(longest_palindrome("ac".to_string()), "c".to_string());
}

nya

期待通り動いてますね。
LeetCode の回文問題でもやってみましょう。

nya

合格してますね。これにて実装は終わりです。

あらためて成果物

https://gist.github.com/ulwlu/c41509710ea99c1d3d3d1a6056961bbd

まとめ

回文検索という限定されたケースではあまり活用道はないですが、重複を察知してその部分の処理を無視するという発想はかなり使えそうだと思いました。
縮小した範囲も検索しないといけない・・・そうすると計算量が爆発する・・・という時に、こういう方法があると思い出せれば効率化の糸口が掴めそうです。

ちなみに Python だと以下のように半分のコード量でできます。

def longestPalindrome(self, s: str) -> str:
  if len(s) == 1:
    return s
  # Need to care about odd
  newstr = '#' + '#'.join(s) + '#'
  lengthOfPalindrome = [1] * len(newstr)
  currentCenter = 0
  rightFromCurrentCenter = 0
  for i in range(0, len(newstr)):
    # Check if this is mirror side
    if rightFromCurrentCenter > i and i > currentCenter:
      # Cut if it exceeds right edge.
      lengthOfPalindrome[i] = min(rightFromCurrentCenter - i, lengthOfPalindrome[2 * currentCenter - i])
      # If this is mirror side, and it exceeds the edge, you should check palindrome. otherwise you can ignore.
      if lengthOfPalindrome[i] + i >= rightFromCurrentCenter:
        currentCenter = i
        rightFromCurrentCenter = lengthOfPalindrome[i] + i
        if rightFromCurrentCenter >= len(newstr) - 1:
          # It means you reached the end, and you can't get any larger number.
          break;
      else:
        continu
    # Checking palindrome.
    radius = (lengthOfPalindrome[i] - 1) // 2
    radius += 1
    while (i - radius >= 0 and
         i + radius <= len(newstr) - 1 and
         newstr[i - radius] == newstr[i + radius]):
      lengthOfPalindrome[i] += 2
      radius +=
  maxLen, centerIndex = max((length, idx) for idx, length in enumerate(lengthOfPalindrome))
  radius = (maxLen - 1) // 2
  ans = newstr[centerIndex-radius:centerIndex+radius+1]
  return ans.replace("#", "")