Open5

競プロ「初中級者が解くべき過去問精選 100 問」全探索:ビット全探索

ドリクロドリクロ

10. ALDS_5_A - 総当たり

main.py
import sys

def main():
    """Solve ALDS_5_A: report whether each query is a sum of some elements of A.

    Reads N, the elements of A, Q, and the query values from standard input,
    then prints "yes" or "no" on its own line for every query, in order.
    """
    N = int(sys.stdin.readline().rstrip())
    A = list(map(int, sys.stdin.readline().rstrip().split()))
    # The query count is implied by the length of the next line; the line
    # still has to be consumed.
    sys.stdin.readline()
    M = list(map(int, sys.stdin.readline().rstrip().split()))

    # Bit brute force: bit j of `mask` decides whether A[j] is selected.
    # Starting at 1 skips the empty selection, so only sums of at least one
    # element are recorded.
    reachable = set()
    for mask in range(1, 1 << N):
        total = sum(A[j] for j in range(N) if (mask >> j) & 1)
        # Record the sum once per subset, after it is complete, instead of
        # once per selected element.
        reachable.add(total)

    for m in M:
        print("yes" if m in reachable else "no")


if __name__ == '__main__':
    main()

実行結果

n≤20, q≤200, 1≤Aの要素≤2000, 1≤mi≤2000
O(N * 2^N + Q)

参考

https://qiita.com/gogotealove/items/11f9e83218926211083a

ドリクロドリクロ

11. AtCoder Beginner Contest 128 C - Switches

main.py
#!/usr/bin/env python3

import sys

def main():
    """Solve ABC128 C: count the switch settings that light every bulb.

    Reads N (switches) and M (bulbs) from standard input, then one line per
    bulb holding the number of connected switches followed by their 1-based
    indices, and finally a line of required parities. Prints how many of the
    2**N on/off assignments give every bulb a count of switched-on connected
    switches with the required parity.
    """

    def read_ints():
        # One whitespace-separated line of integers from standard input.
        return tuple(map(int, sys.stdin.readline().rstrip().split()))

    N, M = read_ints()
    bulbs = []
    for _ in range(M):
        row = read_ints()
        # row[0] is the number of connected switches, the rest are the switches.
        bulbs.append((row[0], row[1:]))
    parities = read_ints()

    lit_patterns = 0
    for mask in range(1 << N):
        # Switch numbers (1-based) that are on in this assignment.
        turned_on = {pos + 1 for pos in range(N) if mask & (1 << pos)}
        # all() stops at the first bulb whose parity does not match.
        if all(
            sum(switches[idx] in turned_on for idx in range(count)) % 2 == parity
            for (count, switches), parity in zip(bulbs, parities)
        ):
            lit_patterns += 1

    print(lit_patterns)


if __name__ == '__main__':
    main()

実行結果

1≤N, M≤10
O(NM * 2^N)

解説

ドリクロドリクロ

12. AtCoder Beginner Contest 002 D - 派閥

main.py
#!/usr/bin/env python3

import sys
import itertools

def main():
    """Solve ABC002 D: print the size of the largest group of mutual friends.

    Reads N (people) and M (friendships) from standard input, followed by M
    lines each holding one pair of 1-based person numbers. Every subset of
    people is tried via bit brute force; a subset counts only if every pair
    inside it is a recorded friendship.
    """
    N, M = map(int, sys.stdin.readline().rstrip().split())

    # Store each friendship as (smaller, larger) so the lookup below does not
    # depend on the order in which a pair was written in the input.
    friends = set()
    for _ in range(M):
        x, y = map(int, sys.stdin.readline().rstrip().split())
        friends.add((min(x, y), max(x, y)))

    best = 0
    for bit in range(1 << N):
        # Built in ascending order, so combinations() below yields every pair
        # as (smaller, larger), matching the keys stored in `friends`.
        members = [i + 1 for i in range(N) if (bit >> i) & 1]
        if all(pair in friends for pair in itertools.combinations(members, 2)):
            best = max(best, len(members))

    print(best)


if __name__ == '__main__':
    main()

実行結果

1≦N≦12, 0≦M≦(N(N−1)/2)
O(N^2 * 2^N)

解説

個人的に沼にはまったエラー

main.py
#!/usr/bin/env python3
 
import sys
import itertools
 
def main():
    """Variant of the ABC002 D solution that collects each candidate group in a set.

    Reads N and M, then M pairs, from standard input and prints the size of
    the largest subset in which every pair produced by
    itertools.combinations is found in FRIEND_XY.

    NOTE(review): FRIEND_XY stores each pair exactly as it was read, while the
    pairs checked against it come out in the iteration order of a set, which
    is not guaranteed to be ascending. See the comments in the loop below.
    """
    N, M = map(int, sys.stdin.readline().rstrip().split())
    # Each friendship is kept as the (x, y) tuple in the order given in the input.
    FRIEND_XY = {tuple(map(int, sys.stdin.readline().rstrip().split())) for _ in range(M)}
    
    result = 0
    # Bit brute force: bit i of `bit` decides whether person i+1 is in the group.
    for bit in range(2**N):
        # NOTE(review): a set does not promise any iteration order, so the
        # members are not necessarily visited smallest-first later on.
        xy_group = set()
        for i in range(N):
            if (bit >> i) & 1:
                xy_group.add(i+1)
        flag = 1
        # combinations() emits pairs in the order xy_group is iterated.
        # NOTE(review): presumably this is what goes wrong - e.g. CPython may
        # iterate {1, 8} as 8, 1, giving (8, 1), which does not match a stored
        # (1, 8); confirm on the target interpreter.
        for x, y in itertools.combinations(xy_group, 2):
            if not {(x, y)} <= FRIEND_XY:
                flag = 0
                break
        if flag:
            result = max(result, len(xy_group))
    
    print(result)
 
if __name__ == '__main__':
    main()

実行結果

xy_group = set() Ver.
始めにこのコードを提出して、WA に……
xy_group = [] にすると AC に……

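当初は set() の「重複する要素を一つにする」特性が悪さをしたのかと考察しましたが、xy_group に追加する値 (i+1) はもともと重複しないため、原因はそこではなさそうです。set は反復順序が保証されないため、itertools.combinations が (8, 1) のように (大, 小) の順でペアを返すことがあり、入力どおり (小, 大) の形で格納している FRIEND_XY との照合に失敗するのが原因と考えられます。リストの場合は昇順に追加した順序がそのまま保たれるので、この問題は起きません。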