👋

ABC186/E - Throne

2020/12/22に公開

問題

自然数 N,S,K が与えられる.
次を満たす最小の自然数 m があるかどうか判定し, あるならそれを答えよ.

S + mK = 0 \bmod N

ただし N10^9 程度で, O(N) 未満で計算して欲しい.

解法

公式では互除法で解いてる.
今からここで書こうとするのは

https://twitter.com/kyopro_friends/status/1341216644727676928

これが全て.

答えが N 以下であることについて

a_m = (S+mK) \mod N として数列

a_0, a_1, \ldots

を考える.
問題はこれが初めて 0 になるものを探せというものである.
もし i<j について a_i = a_j となるなら, a_{i+1} = a_{j+1} になることから, この数列の値はループになることが分かる.
しかもとり得る値が N 個の自然数のいずれかということから, 鳩の巣原理を適用することで, a_0 から a_N までを考えると必ずそこに重複が含まれる.
従って 0 を取る項があるなら必ずこの中に含まれる.

Baby-Step Giant-Step

自然数 m は今 0 以上 N 以下に限ってよいが, これは小さいものから順に次のように, (i \sqrt{N} + j) という和に分解出来る.

  • 0 \sqrt{N} + 0
  • 0 \sqrt{N} + 1
  • 0 \sqrt{N} + 2
  • \vdots
  • 0 \sqrt{N} + \sqrt{N},
  • 1 \sqrt{N} + 0
  • 1 \sqrt{N} + 1
  • 1 \sqrt{N} + 2
  • \vdots
  • 1 \sqrt{N} + \sqrt{N},
  • \vdots
  • (\sqrt{N}-1) \sqrt{N} + 0
  • (\sqrt{N}-1) \sqrt{N} + 1
  • (\sqrt{N}-1) \sqrt{N} + 2
  • \vdots
  • (\sqrt{N}-1) \sqrt{N} + \sqrt{N}.

ただしここで \sqrt{m} は自然数の範囲の平方根であり, ちょうど k^2 \geq m を満たす最小の自然数 k のこととする.

よく見ると上の式は重複が含まれてるんだけど, 小さい順に網羅して調べることが大事なので, そこはとやかく言わないことにする.

さて (S + mK) \bmod N = 0 を調べるのには, (i \sqrt{N} + j) を以て,

(S + i \sqrt{N} K + j K) \bmod N = 0

を調べれば良い.

擬似コードで書くと

r = sqrt(m);
for i in range(0, r) {
  for j in range(0, r) {
    if (S + i * r * K + j * K) % N == 0 {
      yield i * r + j;
    }
  }
}

この range は両端を含む範囲だとする.
この二重ループをそのまま実行すると, 結局 O(N) だが, j についてだけ次のように先にまとめておく.

d = {}  // set
for j in range(0, r) {
  d.insert((j * K) % N);
}

こうしておけば, 各 i について

S + i \sqrt{N} K + x = 0 \mod N

を満たす xd にあるかを問い合わせれば良い.
どうせ mod を取るので, d には予め mod を取った値で入れておくのが良い.
式変形をして,
x = (-S - i \sqrt{N} K) \bmod N

と出来るので, これが d に含まれるかを見ればいい.
d をハッシュテーブルとか二分探索木で持っておけばここは高速に可能.

また, いま問題になってるのは, 存在するかどうかだけではなくて, 存在する場合のその添字 m も求めることだったので,
d はただの集合じゃなくて, それを実現する i の値まで持っておくとよい.
HashMap とか BTreeMap とかそういう辞書にする.

d = {}  // dict
for j in range(0, r) {
  x = (j * K) % N;
  if x not in d {
    d[x] = j;
  }
}

d[x](jK) \bmod N を満たす最小の自然数 j を表す.

for i in range(0, r) {
  y = (S + i * r * K) % N;
  x = N - y;  // y + x = 0 を満たす x
  if x in d {
    m = i * r + d[x];
    print(m);
    return;
  }
}

// 見つからなかった
print(-1);

これで O(\sqrt{N}) 程度(d の問い合わせの時間がここにさらに掛かる).

Discussion