グラフ作りの例で、接続距離（接続先の数）を制限すれば、幾つかの部分グラフ（サブグラフ）が得られることが分かった。この部分グラフを取得して、空間的なクラスターとして捉えることができる。
ここではクラスターに属するノードを取得する例を紹介する。
ã é£æ¥è¡å/edge_indexãç¡åã°ã©ãçšã«å€æ ã
あるノードからエッジを辿って辿り着くことができるノード集合が、そのノードが属するクラスター（サブグラフ）となる。クラスター内のどのノードから辿ってもクラスター内全てのノードに到達可能である。
前章までのedge_indexでは片方向の有向グラフの情報になっているので、あるノードXが接続先ノードとして登録されていても、そのノードXには接続先ノードが無いことがある。
3つのノードは同じクラスターになって欲しいが、有向グラフだとノード1からノード0,2へ辿れない。
逆方向の接続を事前に登録してノードクラスターを見つけやすくしておく。
def undirected_edge_index(
    edge_index
):
    """
    Convert a directed edge list into an undirected (symmetric) one.

    Every edge (u, v) is complemented with its reverse (v, u), then
    duplicate rows are dropped.

    Args:
        edge_index: Array-like of shape (n_edges, 2); each row is a
            (source, destination) node pair. Lists, tuples and numpy
            arrays are accepted, and the input may be empty.

    Return:
        numpy array of shape (n_unique_edges, 2) holding both directions of
        every edge, in the row order produced by np.unique.
    """
    # np.asarray accepts any array-like (not only list) and does not copy
    # an input that already is an ndarray.
    edge_index = np.asarray(edge_index)
    # An empty list becomes a 1-D array, which the column swap below cannot
    # index; return an explicit (0, 2) edge array instead.
    if edge_index.size == 0:
        return edge_index.reshape(0, 2)
    # Append the edges with source and destination swapped
    edge_index = np.concatenate([
        edge_index,
        edge_index[:, [1, 0]]
    ])
    # Remove duplicates
    edge_index = np.unique(edge_index, axis=0)
    return edge_index
ãã¢çšã«ãäžèšã§è·é¢éŸå€ãèšããããããŒã°ã©ãã®edge_indexã䜿çšããã
# Add the reverse direction of every edge so that clusters can be found by
# following edges either way.
edge_index_filtered = undirected_edge_index(edge_index_filtered)
é£æ¥è¡åã®å Žå
転置して加算して、0では無い箇所に1を入れると、無向グラフ用の対称行列が得られる。
# Adding the transpose makes the matrix symmetric: a connection in either
# direction produces a non-zero entry in both positions.
adj_matrix_undirected = adj_matrix + adj_matrix.T
# Set every positive entry (including sums larger than 1) to 1
adj_matrix_undirected[adj_matrix_undirected > 0] = 1
ã 1ã€ã®ããŒããã蟿ãçããããŒãéå ã
まずは1つのノードからクラスターを取り出す機能を紹介する。返り値はクラスターに属するエッジ集合である。
def find_connected_cluster(start_node, edge_index):
    """
    Collect the edges of the subgraph that can be reached from one start node.

    The graph is explored level by level starting at ``start_node``. An edge
    is only followed from a node it contains, so ``edge_index`` should hold
    both directions of every connection for the whole cluster to be found.

    Args:
        start_node: Integer id of the node to start from.
        edge_index: Iterable of edges; each edge is indexable and holds the
            integer ids of its two end nodes at positions 0 and 1.

    Return:
        List of the edges discovered (the same objects as in ``edge_index``,
        in discovery order). The unique node ids in these edges are the nodes
        of the cluster. The list is empty when the start node has no edge
        leading to another node.
    """
    # Map every node to the edges touching it, in edge_index order. Built
    # once so each node lookup does not rescan the whole edge list.
    incident_edges = {}
    for edge in edge_index:
        # A set, so a self-loop is registered only once for its node
        for node in {int(edge[0]), int(edge[1])}:
            incident_edges.setdefault(node, []).append(edge)

    # Visited nodes are kept in a set. The previous bool array was sized by
    # the largest node id (not id + 1), so visiting that node raised
    # IndexError; a set needs no sizing at all.
    visited_nodes = {int(start_node)}
    # Final output: the edges of the subgraph
    cluster = []
    connected_nodes = {int(start_node)}
    # Repeat until there is no node left to expand
    while connected_nodes:
        new_edges = []
        new_nodes = set()
        for node in sorted(connected_nodes):
            # Look at every edge touching this node and keep those that
            # lead to a node not visited before this level
            for edge in incident_edges.get(node, ()):
                u, v = int(edge[0]), int(edge[1])
                if not (u in visited_nodes and v in visited_nodes):
                    new_edges.append(edge)
                    new_nodes.update((u, v))
        # No unexplored edge left: the cluster is complete
        if not new_edges:
            break
        cluster.extend(new_edges)
        # The next level expands only the nodes seen for the first time
        connected_nodes = new_nodes - visited_nodes
        visited_nodes |= connected_nodes
    return cluster
例えばノード番号0から辿れるエッジを全て取得してみる。
# Collect the edges that can be reached starting from node 0
cluster_edges = find_connected_cluster(start_node=0, edge_index=edge_index_filtered)
このエッジ集合からユニークなノード番号を取れば、1つのクラスターのノード集合を取得できる。
# The unique node ids appearing in those edges are the nodes of the cluster
cluster_nodes = np.unique(cluster_edges)
ã å šãŠã®ããŒãã¯ã©ã¹ã¿ãŒãååŸ ã
グラフ内の全ての部分グラフを取得する機能を紹介する。
def find_clusters(
    edge_index
):
    """
    Split the nodes that appear in ``edge_index`` into clusters.

    Each node not yet assigned to a cluster is used as a start node for
    find_connected_cluster; the unique node ids of the returned edges form
    one cluster.

    Args:
        edge_index: Edge list passed on unchanged to find_connected_cluster.

    Return:
        List of clusters; each cluster is a sorted list of node ids. A node
        for which no edge is returned becomes a single-node cluster.
    """
    found = []
    seen = set()  # node ids already assigned to a cluster
    for node in np.unique(edge_index):
        # A node that already sits in a cluster needs no further search
        if node in seen:
            continue
        # Edges reachable from this node
        reachable_edges = find_connected_cluster(node, edge_index)
        if not reachable_edges:
            # Nothing reachable: record the node as a cluster of its own
            found.append([node])
            seen.add(node)
            continue
        # Node ids of the cluster, taken from its edges
        members = list(np.sort(np.unique(reachable_edges)))
        found.append(members)
        seen.update(members)
    return found
# One list of node ids per cluster found in the graph
clusters = find_clusters(edge_index_filtered)
ã°ã©ãããéšåã°ã©ãã®ããŒãã¯ã©ã¹ã¿ãŒãååŸ
ã Union-Findã¢ã«ãŽãªãºã ã
グラフから部分グラフを見つける高速なアルゴリズム。どのノードとどのノードが接続しているのかの情報（edge_index）を元に、繋がった集合に同じクラスター番号を与える。
詳现ã¯ä»¥äžã®åç §ãåèã«ããŠã»ããã
https://zenn.dev/kaityo256/articles/union_find_physics
https://qiita.com/ofutonton/items/c17dfd33fc542c222396
https://atcoder.jp/contests/atc001/tasks/unionfind_a
https://zenn.dev/convers39/articles/ffd666639e7782
まずはUnion-Findアルゴリズムを実装したクラスを用意する。
class UnionFind:
    """
    Union-Find (disjoint-set) structure that assigns one cluster id to every
    set of connected nodes.

    Nodes are the integers 0 .. n_node - 1. The cluster id of a node is the
    id of the root node of its tree in ``parent``.
    """

    def __init__(self, n_node):
        """
        Args:
            n_node: Number of nodes; every node starts as its own cluster.
        """
        self.n_node = n_node
        # Initial cluster ids: one per node, node id == cluster id
        self.parent = list(range(n_node))
        # Rank of each node, used by union() to choose which root survives
        self.rank = [1] * n_node

    def find(self, u):
        """
        Return the cluster id (root node) of node ``u``.

        In the initial state the cluster id of node u is u itself. As
        clusters are merged the id changes, so the root is looked up
        recursively. Every node visited on the way is re-pointed directly at
        the root, which shortens later lookups.
        """
        # u is not a root: resolve the root recursively and store it
        if self.parent[u] != u:
            self.parent[u] = self.find(self.parent[u])
        return self.parent[u]

    def union(self, u, v):
        """
        Merge the clusters of the two connected nodes ``u`` and ``v``.
        """
        root_u = self.find(u)
        root_v = self.find(v)
        # Different clusters: attach the lower-ranked root under the
        # higher-ranked one
        if root_u != root_v:
            if self.rank[root_u] > self.rank[root_v]:
                self.parent[root_v] = root_u
            elif self.rank[root_u] < self.rank[root_v]:
                self.parent[root_u] = root_v
            else:
                # Equal rank: keep root_u and increase its rank
                self.parent[root_v] = root_u
                self.rank[root_u] += 1

    def find_cluster(self, start_node):
        """
        Return all nodes that belong to the same cluster as ``start_node``.
        """
        cluster = self.find(start_node)
        # List every node whose root equals the root of start_node
        cluster_nodes = [node for node in range(len(self.parent)) if self.find(node) == cluster]
        return cluster_nodes

    def find_all_cluster(self, check_nodes=None):
        """
        Group nodes by their cluster id.

        Args:
            check_nodes: Iterable of node ids to group. Defaults to all
                nodes 0 .. n_node - 1.

        Return:
            List of clusters, each a list of node ids in the order they were
            visited.
        """
        from collections import defaultdict
        try:
            from tqdm.auto import tqdm
        except ImportError:
            # tqdm only draws a progress bar; iterate silently without it
            def tqdm(iterable):
                return iterable
        clusters = defaultdict(list)
        if check_nodes is None:
            check_nodes = range(self.n_node)
        for node in tqdm(check_nodes):
            # Use find(), not parent[node]: the direct parent is only the
            # root once the path has been compressed, so grouping by
            # parent[node] could split one cluster into several.
            root = self.find(node)
            clusters[root].append(node)
        return list(clusters.values())
ã€ã³ã¹ã¿ã³ã¹å
インスタンス化にはノード数を指定するだけである。初期状態のインスタンスは、ノードの数だけクラスターがある。（メンバ属性のparentでクラスター番号を管理している。）
# Number of nodes = largest node id + 1
# (assumes node ids are 0-based integers — TODO confirm)
n_node = np.unique(list(edge_index_filtered)).max() + 1
# In the initial state every node is its own cluster
uf = UnionFind(n_node=n_node)
初期状態のクラスター番号
ããŒãã®ã¯ã©ã¹ã¿ãŒçªå·ãæŽæ°
さらにどのノードとどのノードが繋がっているかの接続情報を渡して、クラスター番号を更新する。
# Merge the clusters of the two end nodes of every edge.
# (The loop body is indented: without it the statement is an
# IndentationError.)
for u, v in edge_index_filtered:
    uf.union(u, v)
52個のクラスターがあることがわかった。（1つのノードから成るクラスターもある。）
ããŒãã¯ã©ã¹ã¿ãŒãååŸ
find_cluster()
ã¡ãœããã§1ã€ã®ããŒãããå°éå¯èœãªããŒãéåãååŸã§ããã
# Nodes that share a cluster with node 0
cluster0 = uf.find_cluster(start_node=0)
find_all_cluster()
ではグラフ内の全てのノードクラスターを返す。
# Every cluster in the graph, each as a list of node ids
clusters = uf.find_all_cluster()