RustでManacher Algorithmを実装し、回文検索をO(n)で解く
成果物
Manacher Algorithmとは
通常総当りではO(n^3)、中心から線形探索すればO(n^2)かかる回文検索問題を、O(n)線形時間だけで解くアルゴリズム(wiki)
この動画が一番わかりやすいです。20以上の記事や動画を見た中で一番なるほど、と感動したので記事にします。
Manacher Algorithmの方針
-
abaxaba
という文字列がある - 最初から探索開始
- a->1, b->3, a->1, x->7
- 次にaを探索するのか?しません。ここがポイント。
- [理由] aはxを中心とする回文の範囲内。この時xの左側と右側は、xの回文の範囲内なら全く同じなはず。
- なので、左側のaの1をコピーして終わり。
- 次にbは左側3なのでこれも3・・・と思いきや、これはコピーしません。これが第2のポイント。
- [理由]bから長さ3の範囲は、xを中心とする回文の範囲の右端に到達してしまいます。これではもし
abaxaba - axaba
という5文字が後に続いてたらb=3は誤りです。なので右端に到達する値を保持していたら今度はその文字から検索する必要があるので、bは探索します。 - 1-8を繰り返し、右端が列終端に到達したら、以降の文字列は確実に中心点からの半径より短い = 最大長にはなりえないので修了。
何故Manacher AlgorithmはO(n)なのか
パッと見、最悪ケースだと全ての文字を探索して全ての文字についてn回探しそうだからO(n^2)では?と思われます(海外でも結構質問記事が多い参考1 , 参考2)
ポイントは
- 探索を開始するのは、現在回文が保証されている範囲の右端を超えたときのみ
- 回文が保証されている点は探索せずスキップできる
つまり回文が保証されている範囲、中心点、右端は必ず単調増加しており探索するindexはただ右に伸びていき重複しない事がわかります。混乱の元となるのは「左側も探索しないといけない」という事実なのですが、これは中心点からlist[center-radius] == list[center+radius]
という探索を行う事で計算量左端index-center
= 右端index-center
です。
各比較操作が成功するたびに右に1段階前進しており、中心点も右端も絶対に縮小する事がないので探索はn回だけしか行なっておらず、計算量はO(n)で完了します。
Manacher Algorithmの理解図
thanks: draw.io
Manacher Algorithmの実装
上記のアルゴリズム理解から、
- そのindexを中心とする回文の長さを保持する配列
- 現在focusしてる回文の中心点
- 現在focusしてる回文の右端
があれば設計できるなと気持ち覚えておきます。
1: まず偶数時をケアする
まず中央からの線形探索ができるように、偶数時のケア用のダミー文字を挿入します。
「aba」だとbから両端を見れば回文だとわかりますが
「aa」だと探索の始点がありません。なので「#a#a#」という風にダミー文字をいれて、#からaaを見つけられるようにします。
// MEMO: We need to detect odd palindrome as well,
// therefore, inserting dummy string so that
// we can find a pair with dummy center character.
let mut chars: Vec<char> = Vec::with_capacity(s.len() * 2 + 1);
for c in s.chars() {
chars.push('#');
chars.push(c);
}
chars.push('#');
2: 回文になっているかを探索する処理を実装する
// List: storing the length of palindrome at each index of string
let mut length_of_palindrome = vec![1usize; chars.len()];
// Integer: Current checking palindrome's center index
let mut current_center: usize = 0;
// Integer: Right edge index existing the radius away from current center
let mut right_from_current_center: usize = 0;
for i in 0..chars.len() {
//....
// Integer: Current radius from checking index
// If it's copied from left side and more than 1,
// it means it's ensured so you don't need to check inside radius.
let mut radius: usize = (length_of_palindrome[i] - 1) / 2;
radius = radius + 1;
// 2: Checking palindrome.
// Need to care about overflow usize.
while i >= radius && i + radius <= chars.len() - 1 && chars[i - radius] == chars[i + radius]
{
length_of_palindrome[i] = length_of_palindrome[i] + 2;
radius = radius + 1;
}
}
最初のlet mut radius: usize = (length_of_palindrome[i] - 1) / 2;
が少しわかりづらいかも知れません。コメントにある通り、もしその地点の値が左側からコピーしてきた「3,5,10」などだったとします。
これはcurrent_centerを中心とする保証された回文の長さです。という事はその中は探索しなくてよくて、その次1個先から見れば良いことになります(この1個先が、1行次のradius = radius + 1
)。
ですので効率化の手段になりえるので採用してます。
whileの直後の「i >= radius」、本当は「i - radius >= 0」と後々の処理とコンテキスト合わせたいんですけど、インデックスはusizeなのでoverflow panicを起こしてしまうためこうしてます(負になる可能性がある)。
特に制約上問題ないことが自明な条件なときに、Rustだとオーバーフローが起きる可能性のある部分をコンパイル時にPanic出してくれるのが逆にムムムとなります(実務ではとてもありがたい神機能ですが)。
本来overflowするような場合はその最大値にまとめてくれるsaturationとか、切り捨ててくれるwrapping_add/subなどがあるのですが、usizeが負になるケースはそういうメソッドでは対応できないのでこうしてます、ですが・・・・・・ちょっと見通しが悪いので、もっとスマートに書けれたらいいなと思ってます。
もし誰かご存知でしたら教えて下さい。
3: 探索しないで左側からコピーして効率化する処理を実装する
for i in 0..chars.len() {
// 1: Check if we are looking at right side of palindrome.
if right_from_current_center > i && i > current_center {
// 1-1: If so copy from the left side of palindrome.
// If the value + index exceeds the right edge index, we should cut and check palindrome later #3.
length_of_palindrome[i] = std::cmp::min(
right_from_current_center - i,
length_of_palindrome[2 * current_center - i],
);
// 1-2: Move the checking palindrome to new index if it exceeds the right edge.
if length_of_palindrome[i] + i >= right_from_current_center {
current_center = i;
right_from_current_center = length_of_palindrome[i] + i;
// 1-3: If radius exceeds the end of list, it means checking is over.
// You will never get the larger value because the string will get only shorter.
if right_from_current_center >= chars.len() - 1 {
break;
}
} else {
// 1-4: If the checking index doesn't exceeds the right edge,
// it means the length is just as same as the left side.
// You don't need to check anymore.
continue;
}
}
// ....
}
コメントにむちゃくちゃメモをしているのでほぼ説明する必要がないと思います。
もし所見のときにパット見わかりづらいところを上げるとすればright_from_current_center - i
の部分でしょうか。これは現在チェックしているiに、本来左側から10がコピーできるとします。しかしその地点が右端だった場合、回文が保証されているのは右端からiのインデックスまでの距離、すなわち1だけです。
回文が保証されてる範囲のみにそろえて後で回文探索を開始したいため、minをかけています。
4: テスト
ここまで実装できたので試しにテストを入れてみましょう。
#[test]
fn test_longest_palindrome() {
assert_eq!(longest_palindrome("babad".to_string()), "aba".to_string());
assert_eq!(longest_palindrome("cbbd".to_string()), "bb".to_string());
assert_eq!(longest_palindrome("a".to_string()), "a".to_string());
assert_eq!(longest_palindrome("ac".to_string()), "c".to_string());
}
期待通り動いてますね。
LeetCodeの回文問題でもやってみましょう。
合格してますね。これにて実装は終わりです。
あらためて成果物
まとめ
回文検索という限定されたケースではあまり活用道はないですが、重複を察知してその部分の処理を無視するという発想はかなり使えそうだと思いました。
縮小した範囲も検索しないといけない・・・そうすると計算量が爆発する・・・という時に、こういう方法があると思い出せれば効率化の糸口が掴めそうです。
ちなみにPythonだと以下のように半分のコード量でできます。
def longestPalindrome(self, s: str) -> str:
if len(s) == 1:
return s
# Need to care about odd
newstr = '#' + '#'.join(s) + '#'
lengthOfPalindrome = [1] * len(newstr)
currentCenter = 0
rightFromCurrentCenter = 0
for i in range(0, len(newstr)):
# Check if this is mirror side
if rightFromCurrentCenter > i and i > currentCenter:
# Cut if it exceeds right edge.
lengthOfPalindrome[i] = min(rightFromCurrentCenter - i, lengthOfPalindrome[2 * currentCenter - i])
# If this is mirror side, and it exceeds the edge, you should check palindrome. otherwise you can ignore.
if lengthOfPalindrome[i] + i >= rightFromCurrentCenter:
currentCenter = i
rightFromCurrentCenter = lengthOfPalindrome[i] + i
if rightFromCurrentCenter >= len(newstr) - 1:
# It means you reached the end, and you can't get any larger number.
break;
else:
continu
# Checking palindrome.
radius = (lengthOfPalindrome[i] - 1) // 2
radius += 1
while (i - radius >= 0 and
i + radius <= len(newstr) - 1 and
newstr[i - radius] == newstr[i + radius]):
lengthOfPalindrome[i] += 2
radius +=
maxLen, centerIndex = max((length, idx) for idx, length in enumerate(lengthOfPalindrome))
radius = (maxLen - 1) // 2
ans = newstr[centerIndex-radius:centerIndex+radius+1]
return ans.replace("#", "")