🤪

ABC_406_E 'Popcount Sum 3'をRustで解く

に公開

今回はRubyではなくRustを使います。理由はカッコいいからです。

問題

N以下、かつpop countがKの整数を全部合計してください。答えは998244353の剰余で

https://atcoder.jp/contests/abc406/tasks/abc406_e

解法

いくつか存在しますが、私の観測範囲で最も一般的なのは桁DPでした。

https://x.com/yiwiy9/status/1924810348676448564

https://atcoder.jp/contests/abc406/editorial/13071

公式解説だと誤字なのか何なのか変な記法になっているのでメモします現在修正済み
https://x.com/hotpepsi/status/1924797923390579159

桁DPで「N以下の自然数のうちpop countがいくらの自然数が何個か」を考えるのは難しいですが、実は「以下」ではなく「未満」の場合は簡単に解けます。
以降、テストケースに合わせてNを20、Kを2として例示します。

アルゴリズム

第一に、20を二進法で書くと10010、16の位と4の位のbitが1になっています。また、この両方のbitが立っている自然数は絶対に20以上になってしまいます。
逆(対偶)に言うと、20未満の自然数をbit単位で見ていった場合、16の位と4の位の片方は絶対に0です。
「二進法の16の位と4の位の両方が1ならば、示している整数は20以上である」の対偶で「整数が20未満の場合、二進法の16の位と4の位の片方は0」です。

また第二に、これは実験してみないと直感的にはわかりにくいですが
「20では1だった桁を0にした場合、それより下の桁だけは自由にして良い。上の桁をいじったら駄目」のルールがあります。たとえば17は二進法で10001。4の位の桁を1から0にしているため、2の位と1の位は自由に決められます。一方で24は二進法で11000と4の位が0なのに20以上、これは4の位より上のbitを操作してしまったためです。

以上から、20未満の自然数をbitごとに分類する場合、以下2つのルールでdpを書けます。

  1. 現在より上の桁の時点で、既に20で1だったbitを0にしているので自由に選択できる
  2. ここまでは20と同じビット状況だったが、20の二進数で1だった桁を今回初めて0にする

dp

行が今見ている桁、列がpop count数、値が「その桁、そのbit数で20未満の自然数の集合」です。

32の位

欲しい自然数は20未満なので、32の位(およびそれ以上)は空集合です。
0ですらありません。「32の位に0が入っていたら、それ以降は何をどうやっても20未満」とはならないからです。

確定済みの位 pop count 0 pop count 1 pop count 2
32の位 [] [] []
16の位

ここで先程示したルール2を使います。16の位が0である事を確定させます。するとそれ以降はどうbitを立てても20未満である事が確定します

確定済みの位 pop count 0 pop count 1 pop count 2
32の位 [] [] []
16の位 [0] [] []
8の位

20は8の位が0なので、ルール1のみです。16の位が0だと確定している整数が一個あるので、8の位が0と1両方の場合を作ります。1ビット立てたらpop countも1上がる事に注意。

確定済みの位 pop count 0 pop count 1 pop count 2
32の位 [] [] []
16の位 [0] [] []
8の位 [0] [8] []
4の位

まずルール1。0と8のそれぞれに4の位が0と1のバージョンを追加します。

確定済みの位 pop count 0 pop count 1 pop count 2
32の位 [] [] []
16の位 [0] [] []
8の位 [0] [8] []
4の位 [0] [8, 4] [12]

20の4の位は1なので、さらにここからルール2。「16の位は1だが4の位は0。故にこれ以降はどうビットを立てても20未満確定」を追加します。

確定済みの位 pop count 0 pop count 1 pop count 2
32の位 [] [] []
16の位 [0] [] []
8の位 [0] [8] []
4の位 [0] [8, 4, 16] [12]
2の位

20は2の位が0なので、ルール1のみです。0,8,4,16にそれぞれ2の位が1と0のバージョンを追加します。12は既にpop countが2なので、14にしたらpop countが3になってしまい求める出力に関係なくなってしまうので略します。誰かpop count 4まで書かれた記事を書いてください。

確定済みの位 pop count 0 pop count 1 pop count 2
32の位 [] [] []
16の位 [0] [] []
8の位 [0] [8] []
4の位 [0] [8, 4, 16] [12]
2の位 [0] [8, 4, 16, 2] [12, 10, 6, 18]
1の位

やる事は2の位と同じなので略。

確定済みの位 pop count 0 pop count 1 pop count 2
32の位 [] [] []
16の位 [0] [] []
8の位 [0] [8] []
4の位 [0] [8, 4, 16] [12]
2の位 [0] [8, 4, 16, 2] [12, 10, 6, 18]
1の位 [0] [8, 4, 16, 2, 1] [12, 10, 6, 18, 9, 5, 17, 3]

dpは「20未満」、問題文は「20以下」なので20だけ別に計算します。これで問題文の例と同じ[12, 10, 6, 18, 9, 5, 17, 3, 20]の集合が得られます。

高速化

ここから高速化します。
先程のテーブルのうち、4の位から2の位への遷移を見てみます。

確定済みの位 pop count 0 pop count 1 pop count 2
4の位 [0] [8, 4, 16] [12]
2の位 [0] [8, 4, 16, 2] [12, 10, 6, 18]

求めるのは総和なので、加算に書き換えてみます。

確定済みの位 pop count 0 pop count 1 pop count 2
4の位 '0' '8+4+16' '12'
2の位 '0' '8+4+16+2' '12+10+6+18'

4の位のpop count 1の8+4+16がわかりやすいでしょうか?

  • 2の位のpop count 1において、前3つはそのまま降りてきています。(8+4+16)+2と表現できます。
  • 2の位のpop count 2において、後ろ3つは8+4+16のそれぞれに2を足した物です。12+((8+2)+(4+2)+(16+2))と書けます。さらに+2を一箇所にまとめて12+((8+4+16)+2*3)と表現できます。公式解説だとここが変な記法になっています現在修正済み

以上の事から、「集合の合計値」と「集合の要素数」だけ記憶しておけば集合の要素一個一個の数字は忘れてもいい事がわかります。

確定済みの位 pop count 0 pop count 1 pop count 2
4の位 {合計0, 要素数1} {合計28, 要素数3} {合計12, 要素数1}

これだけあれば、

  • 2の位はまず一行上をそのまま降ろして{合計0, 要素数1}, {合計28, 要素数3}, {合計12, 要素数1}
  • 4の位のpop count 0の要素それぞれの2の位を1にする。合計は0 + 2 * 1 = 2、要素数は1。これを2の位のpop count 1に足して{合計0, 要素数1}, {合計30, 要素数4}, {合計12, 要素数1}
  • 4の位のpop count 1の要素それぞれの2の位を1にする。合計は28 + 2 * 3 = 34、要素数は3。これを2の位のpop count 2に足して{合計0, 要素数1}, {合計30, 要素数4}, {合計46, 要素数4}
確定済みの位 pop count 0 pop count 1 pop count 2
4の位 {合計0, 要素数1} {合計28, 要素数3} {合計12, 要素数1}
2の位 {合計0, 要素数1} {合計30, 要素数4} {合計46, 要素数4}

ルール1はこれで計算できるわけです。一方でルール2は若干ややこしいです。

確定済みの位 pop count 0 pop count 1 pop count 2
8の位 {合計0, 要素数1} {合計8, 要素数1} {合計0, 要素数0}
  • まずルール1を全部計算して{合計0, 要素数1}, {合計12, 要素数2}, {合計12, 要素数1}
  • これにルール2を追加。「4の位まで確定した16。pop countは1」なので4の位のpop count 1に合計16,要素数1を追加。{合計0, 要素数1}, {合計28, 要素数3}, {合計12, 要素数1}
確定済みの位 pop count 0 pop count 1 pop count 2
8の位 {合計0, 要素数1} {合計8, 要素数1} {合計0, 要素数0}
4の位 {合計0, 要素数1} {合計28, 要素数3} {合計12, 要素数1}

公式解説では3次元配列ではなく2次元配列を二個使い、要素数配列をdp1、合計配列をdp2としています。

ACコード

https://atcoder.jp/contests/abc406/submissions/66005333

ここまでテーブル全体を記録していましたが、コード内では作業する行と一行前の二行以外は消しています。Rubyの使いすぎです。
29行目と30行目で入力を受け取っています。
60桁繰り返しても正解になりますが、ループ回数を抑えるために31行目でNのビット数を調べています。記事の中で32の位より上を計算していないのと同じです。
作業するdp行がcurrent、一行上の記録として残しておく行がprev
45行目から51行目までがルール1で、81行目までがルール2です。
68行目にごちゃごちゃやってますが、これはNと111...(下から数えて0埋めしたい桁数)...1を&演算してNの下digit_from_bottom桁を抜き出し、Nからこれを引いて下digit_from_bottom桁を0埋めした数字です。もっとスマートにやる計算方法があるのでしょうか? 誰か教えてください。
74行目から77行目までの早期コンティニューは、pop countの値が列の長さを超えてしまっていた場合です。コメントで「立っている桁数が足りない場合」とありますがこれは逆で、「求めるpop count値に対してNを上からi桁目まで再現した時のpop count値が多すぎる場合」です。
85行目は「以下」と「未満」の違いを埋めています。記事内で20ぴったりが答えになるので最後に20を追加しているのと同じです。


Rustを本番で使いこなせる気がしません。

Discussion