Zenn
Closed3

AtCoder ARC148 F - 998244353 → 1000000007 回答例

Mizar/みざーMizar/みざー
  • xa(ya+za)mod  264x_a\leftarrow(y_a+z_a)\mod 2^{64}
  • xm(ym×zm)mod  264x_m\leftarrow(y_m\times z_m)\mod 2^{64}
  • xryrmod  998244353x_r\leftarrow y_r\mod 998244353

のような演算パターンの組み合わせで a×bmod  1000000007a\times b\mod 1000000007 を作る、モンゴメリ剰余乗算を上手く拡張して実装できるかという問題。

https://atcoder.jp/contests/arc148/tasks/arc148_f

以下は自分の19行の回答例。(既にユーザ解説で18行での解法も紹介されているのでそちらも参照

https://atcoder.jp/contests/arc148/submissions/34814433

  • N=1000000007N=1000000007
  • R=998244353R=998244353
  • Nmod  264=18446744072709551609-N\mod 2^{64}=18446744072709551609
  • N=R(N1mod  R)=4915446N'=R-(N^{-1}\mod R)=4915446
  • N1mod  R=993328907N^{-1}\mod R=993328907
  • R1mod  264=996491785301655553R^{-1}\mod 2^{64}=996491785301655553
  • R2=R2mod  N=320946142R_2=R^2\mod N=320946142
  • Rc=R2R1mod  264=8456992263029997534R_c=R_2*R^{-1}\mod 2^{64}=8456992263029997534
  • Q=264mod  R=932051910Q=2^{64}\mod R=932051910
  • P=(NQ)/Qmod  R=696235320P=(N-Q)/Q\mod R=696235320

1回目のモンゴメリ変換を
t(T+(((Tmod  R)Nmod  R)N))/R{0t<TR+N};t\leftarrow (T+(((T\mod R)*N'\mod R) * N))/R\quad\{0\le t\lt\frac{T}{R}+N\};
if t<N then return (t) else return (tN);\text{if }t\lt N\text{ then return }(t)\text{ else return }(t-N);
2回目のモンゴメリ変換を
t(T(((Tmod  R)N1mod  R)N))/R{N<t<TR};t\leftarrow (T-(((T\mod R)*N^{-1}\mod R) * N))/R\quad\{-N\lt t\lt\frac{T}{R}\};
if t<0 then return (t+N), else return (t);\text{if }t\lt 0\text{ then return }(t+N),\text{ else return }(t);
と2回目のモンゴメリ変換を変形して適用した場合に短縮できるかの試み。(出来てはいません)

Q=264mod  R=932051910Q=2^{64}\mod R=932051910 が偶数で法 2642^{64} に対する乗法逆元を持たないため、ステップ数短縮には今のところ繋がっていません。

/*

19
mul T A B
rem U T
mul U U 4915446
rem U U
mul U U 1000000007
add U U T
mul T U 8456992263029997534
rem U T
mul U U 993328907
rem U U
mul U U 18446744072709551609
add U U T
mul T U 996491785301655553
mul U T 998244353
rem U U
mul V U 696235320
rem V V
add U U V
add C T U

*/

const N: u64 = 1000000007;
const R: u64 = 998244353;
const NEG_N: u64 = 18446744072709551609; // 2 ** 64 - n; 
const N_DASH: u64 = 4915446; // n * ndash % r == r - 1
const N_INV: u64 = 993328907; // n * ninv % r == 1
const R_INV: u64 = 996491785301655553; // r * rinv % (2**64) == 1
const R2: u64 = 320946142; // == r * r % n
//const RC: u64 = 8456992263029997534; // R_INV * R2 % (2**64)
const Q: u64 = 932051910; // 2**64 % R
const P: u64 = 696235320; // Q * P % R + Q == N

fn reduction1(t: u64) -> u64 {
    // t := (T + (((T % R) * N_DASH % R) * N)) * R_INV (0 <= t < T/R + N)
    let u = t % R;
    let u = u.wrapping_mul(N_DASH);
    let u = u % R;
    let u = u.wrapping_mul(N);
    let u = u.wrapping_add(t);
    let _t = u.wrapping_mul(R_INV);
    assert_eq!(_t * R % N, t % N);
    _t
}

fn reduction2(t: u64) -> u64 {
    // t := (T + (((T % R) * N_INV % R) * NEG_N)) * R_INV (-N < t < T/R)
    let u = t % R;
    let u = u.wrapping_mul(N_INV);
    let u = u % R;
    let u = u.wrapping_mul(NEG_N);
    let u = u.wrapping_add(t);
    let _t = u.wrapping_mul(R_INV);
    assert_eq!(_t.wrapping_add(N) * R % N, t % N);
    _t
}

fn fix(t: u64) -> u64 { // if t < 0 then return t + N else return t + 0
    let u = t.wrapping_mul(R);
    let u = u % R;
    assert!(((t as i64) < 0 && u == Q) || ((t as i64) >= 0 && u == 0));
    let v = u.wrapping_mul(P);
    let v = v % R;
    let u = u.wrapping_add(v);
    assert!(((t as i64) < 0 && u == N) || ((t as i64) >= 0 && u == 0));
    let _c = t.wrapping_add(u);
    _c
}

pub fn mulmod(a: u64, b: u64) -> u64 {
    let t = reduction1(a * b);
    let t = reduction2(t * R2);
    let c = fix(t);
    assert_eq!(a * b % N, c);
    c
}
Mizar/みざーMizar/みざー

18行に短縮できました。reduction2 の最後の /R/R と fix最初の R*R がもろ被りだったので、ここを除去すれば18行になります。以下回答例コードと概念コード。

https://atcoder.jp/contests/arc148/submissions/34818980

/*
18
mul T A B
rem U T
mul U U 4915446
rem U U
mul U U 1000000007
add U U T
mul T U 8456992263029997534
rem U T
mul U U 993328907
rem U U
mul U U 18446744072709551609
add U U T
mul T U 996491785301655553
rem U U
mul V U 696235320
rem V V
add U U V
add C T U
*/

const N: u64 = 1000000007;
const R: u64 = 998244353;
const NEG_N: u64 = 18446744072709551609; // 2 ** 64 - n; 
const N_DASH: u64 = 4915446; // n * ndash % r == r - 1
const N_INV: u64 = 993328907; // n * ninv % r == 1
const R_INV: u64 = 996491785301655553; // r * rinv % (2**64) == 1
const R2: u64 = 320946142; // == r * r % n
//const RC: u64 = 8456992263029997534; // R_INV * R2 % (2**64)
const Q: u64 = 932051910; // 2**64 % R
const P: u64 = 696235320; // Q * P % R + Q == N

fn reduction1(t: u64) -> u64 {
    // t := (T + (((T % R) * N_DASH % R) * N)) * R_INV (0 <= t < T/R + N)
    let u = t % R;
    let u = u.wrapping_mul(N_DASH);
    let u = u % R;
    let u = u.wrapping_mul(N);
    let u = u.wrapping_add(t);
    let _t = u.wrapping_mul(R_INV);
    assert_eq!(_t * R % N, t % N);
    _t
}

fn reduction2(t: u64) -> (u64, u64) { // t の値域が異なるモンゴメリ剰余乗算の変種を使う
    // t := (T + (((T % R) * N_INV % R) * NEG_N)) * R_INV (-N < t < T/R)
    let u = t % R;
    let u = u.wrapping_mul(N_INV);
    let u = u % R;
    let u = u.wrapping_mul(NEG_N);
    let u = u.wrapping_add(t);
    let _t = u.wrapping_mul(R_INV);
    assert_eq!(_t.wrapping_add(N) * R % N, t % N);
    (_t, u)
}

fn fix(t: u64, u: u64) -> u64 { // if t < 0 then return t + N else return t + 0
    //let u = t.wrapping_mul(R); // reduction2 の最後で /R しているのでその前から値を貰う
    let u = u % R; // t が負なら Q=2**64%R 、 t が非負なら 0 の値を u に代入
    assert!(((t as i64) < 0 && u == Q) || ((t as i64) >= 0 && u == 0));
    let v = u.wrapping_mul(P);
    let v = v % R;
    let u = u.wrapping_add(v); // t が負なら N 、 t が非負なら 0 の値を u に代入
    assert!(((t as i64) < 0 && u == N) || ((t as i64) >= 0 && u == 0));
    let _c = t.wrapping_add(u); // t が負なら t+N 、 t が非負なら t の値を _c に代入
    _c
}

pub fn mulmod(a: u64, b: u64) -> u64 {
    let t = reduction1(a * b);
    let (t, u) = reduction2(t * R2);
    let c = fix(t, u);
    assert_eq!(a * b % N, c);
    c
}
このスクラップは2022/09/12にクローズされました
ログインするとコメントできます