🌵

[ABC258G] Triangle を Pythonで解く

2022/07/07に公開

https://atcoder.jp/contests/abc258/tasks/abc258_g

考え方

ij を固定して i - j 間に辺がある場合のみを考えます。さらに i - k 間、j - k 間に辺をもつ k を数えることは、隣接行列 A の行成分 A[i] , A[j] を二進数としてみて、AND演算した場合に立っているビットを数える(ビットカウントをする)ことと同じことになります。
C++では、ビットカウントを高速で行うことができるのですが、Pythonだと通常の実装では時間がかかります。下記のコードは出力は正しいですが、このままだとTLEになりますので、bin(a).count("1")のところを高速化する必要があります。

N = int(input())
A = [int(input(),2) for _ in range(N)]

ans = 0
for i in range(N-1):
    for j in range(i+1,N):
        if A[i] >> (N-j-1) & 1 == 0:
            continue
        a = A[i] & A[j]
        ans += bin(a).count("1")

print(ans//3)

実装メモ

この後で実装するビットカウントを行う関数(popcnt)が、64桁までしか対応しません。本問の制約は N=3000 までとなっていますので、A を50桁ごとに分割した A2 を準備して、64桁までの実装で処理できるようにします。

N = int(input())
A = [input() for _ in range(N)]

def int_bin(x): 
    return int(x,2) if x else 0
A2 = [[int_bin(A[i][60*j:60*j+60]) for j in range(50)] for i in range(N)]

Pythonでもビットカウントを高速で処理できる関数を準備します。ビットカウントの原理については、以下の記事がとても参考になります。
https://nixeneko.hatenablog.com/entry/2018/03/04/000000

https://www.slideshare.net/KMC_JP/slide-www
ちなみに、繰り返しが気になって、MSK用の整数をMSK[i]に入れて、for文をまわして処理したらTLEになってしまいました。こんなことでも遅くなるんだなと勉強になりました。

MSK1 = 0x5555555555555555
MSK2 = 0x3333333333333333
MSK3 = 0x0f0f0f0f0f0f0f0f
MSK4 = 0x00ff00ff00ff00ff
MSK5 = 0x0000ffff0000ffff
MSK6 = 0x00000000ffffffff
 
def popcnt(x):
  x=(x & MSK1) + ((x>>1) & MSK1)
  x=(x & MSK2) + ((x>>2) & MSK2)
  x=(x & MSK3) + ((x>>4) & MSK3)
  x=(x & MSK4) + ((x>>8) & MSK4)
  x=(x & MSK5) + ((x>>16) & MSK5)
  x=(x & MSK6) + ((x>>32) & MSK6)
  return x
16進数表記(16桁) 2進数表記(64桁)
5555555555555555 0101010101010101010101010101010101010101010101010101010101010101
3333333333333333 0011001100110011001100110011001100110011001100110011001100110011
0f0f0f0f0f0f0f0f 0000111100001111000011110000111100001111000011110000111100001111
00ff00ff00ff00ff 0000000011111111000000001111111100000000111111110000000011111111
0000ffff0000ffff 0000000000000000111111111111111100000000000000001111111111111111
00000000ffffffff 0000000000000000000000000000000011111111111111111111111111111111

最後に3で割る理由は、i<j で処理しており、3つの整数の順列である6通りのうち半分の3通りが重複するからです。

ans = 0
for i in range(N-1):
    for j in range(i+1,N):
        if A[i][j] == "0":
            continue

        for t in range(50):
            ans += popcnt(A2[i][t] & A2[j][t])

print(ans//3)

Discussion