유니온-파인드와 최소 신장 트리

유니온-파인드, 크루스칼, 프림 알고리즘

Lifthus
7 min readDec 5, 2023
MST
MST

각 요소들이 속하는 집합을 다루는 Union-find와, 이를 이용해 최소 신장 트리를 Kruskal 알고리즘, 그리고 최소 신장 트리를 구하는 다른 방법인 Prim 알고리즘에 대해 알아보자.

Union-find

이 알고리즘은 간단히 각 요소들이 속하는 집합을 병합하고 어떤 집합에 속하는지 찾아내는 알고리즘이다. 최소 기능만 구현하면 병합이나 탐색 과정에서 트리 높이가 계속 높아져서 O(N)에 도달할 수 있기 때문에 최적화가 적용된 버전에 대해 알아보자.

먼저 각 집합의 루트 노드는 해당 집합의 사이즈만큼의 음수를 요소로 가지고, 집합에 포함되는 다른 노드는 루트로 향하는 부모 인덱스를 요소로 가지도록 표현하자. 그럼 처음에는 모든 요소가 각각 자신의 집합을 형성하기 때문에 -1이 될 것이다.

find 함수는 이제 루트를 탐색하기 위해 음수 요소를 가진 노드를 발견할 때 까지 재귀적으로 부모 인덱스를 따라가면 되고, union 함수에서는 병합하는 루트가 병합 당하는 루트의 음수 사이즈 요소를 자신에게 그대로 더하고 병합 당하는 노드 쪽은 새로운 루트의 인덱스를 요소로 가지면 된다.

find 함수에서 최적화할 부분은, 재귀적으로 부모 인덱스를 따라 가는 동시에 최종적으로 발견되는 부모 인덱스로 경로 상의 모든 인덱스를 갱신하는 것이다. 이렇게 하면 나중에는 이 경로를 다 따라가지 않고 바로 루트를 발견할 수 있다. 이 기법을 path compression이라고 한다.

union 함수에서는 위에 음수로 표현한 사이즈에 기반한 union by rank 기법을 사용할 수 있다. 핵심 원리는 더 큰 트리 밑에 더 작은 트리를 붙이는 것이다. 따라서 루트 요소가 나타내는 음수에 기반해서 더 작은 트리의 부모 인덱스를 더 큰 트리의 루트로 교체한다. 트리를 붙이면 병합 당하는 쪽은 깊이가 한 칸 깊어지는데, 당연히 하나라도 깊이가 덜 깊어지는 편이 효율적이기 때문에 작은 트리를 큰 트리 밑에 붙이는 것이다. rank를 깊이로 설정할 수도 있는데 그 경우도 당연히 깊이가 더 깊은 루트에 얕은 트리를 붙이면 트리가 더 깊어지는 것을 최소화할 수 있을 것이다.

코드를 보자.

# 각 루트, 최초 사이즈에 따라  -1 로 설정
p = [-1 for i in range(6)]

def find(x):
if p[x]>=0: # 루트가 아니면
p[x]=find(p[x]) # 루트를 찾아서 갱신
return p[x]
else:return x # 음수라서 루트면 인덱스 그대로 반환

def union(a,b):
pa = find(a)
pb = find(b)
# 둘 다 루트라서 사이즈의 음수, 즉 a 집합 사이즈가 더 작으면 스왑
if p[pa]>p[pb]:pa,pb=pb,pa
# 조건부 스왑으로 항상 a 집합 사이즈가 더 큼
p[pa]+=p[pb] # a 집합에 b 집합 사이즈 더하고
p[pb]=pa # 더 작은 b 집합의 루트는 a 집합 루트로 갱신

union(2,3)
union(1,2)
union(0,1)
union(4,5)
print(p)
print(find(3))
print(p)

전체 노드 수가 N일 때 find의 시간복잡도는 트리의 높이에 비례하는데, union이 항상 더 큰 트리에 작은 트리를 붙이기 때문에 트리를 최대한 높게 만들려하면 이진 트리 형태로 갈 수 밖에 없기 때문에 O(logN)이 시간복잡도가 되고, union의 시간복잡도는 O(2logN+상수)라 결국 O(logN)으로 동일하다. 다만 최악의 경우들을 평균내면 α가 아커만 함수의 역수일 때 O(α(n))이라고 하는데 이는 사실 상 일반적인 경우에서 상수 시간이나 다름 없다고 한다.

Kruskal’s algorithm

Kruskal 알고리즘은 무방향 그래프에서 최소 신장 트리를 구하는 알고리즘이다. 최소 신장 트리는 간단히 설명하면 그래프에서 최대한 적은 가중치만 가지고 모든 노드를 연결하는 그래프를 추출한 것이라 생각하면 된다. 동작은 다음과 같다.

먼저 모든 간선들을 가중치 오름차순으로 정렬하고, 위에 나온 union-find를 사용해 작은 간선부터 양쪽 노드를 합쳐 나가는데, 간선의 양쪽 노드가 이미 같은 집합에 속하면 해당 간선은 건너 뛰고 다른 집합이면 MST 구성 간선으로 선택하는 것을 반복한다.

원리는 다음과 같다. 이 알고리즘을 수행하며 중간에 간선 A를 살펴본다고 하자. A를 선택하지 않는 경우는 간선 양쪽이 이미 연결돼있는 경우인데 당연히 무의미하다. A를 선택하는 경우의 고려 사항은, A대신 A보다 뒤에 정렬돼있는 B를 선택하는게 나을 가능성, 즉 A가 정말 최선인가이다. B가 지금 A가 연결하는 두 집합과 다른 집합을 연결하는 간선이라면 어차피 나중에 선택될 것이니까 논외고, B와 A가 같은 두 집합을 연결하는데 B가 더 나은 결과를 만들어내야 한다는 말인데, 이는 당연히 항상 A의 가중치가 B와 같거나 작기 때문에 말이 안된다. 어차피 B를 포함해서 B를 제외한 나머지 노드들도 다 같은 집합으로 연결된 채로 A랑 연결되기 때문에 간선 B를 선택한다고 더 효율적인 경로를 놓칠 가능성도 없다. 그럼 이제 코드를 보자.

p = [-1 for i in range(4)]

def find(x):
if p[x]>=0:
p[x]=find(p[x])
return p[x]
else:return x

def union(a,b):
pa = find(a)
pb = find(b)
if p[pa]>p[pb]:pa,pb=pb,pa
p[pa]+=p[pb]
p[pb]=pa

E = [
(1,0,1),(2,1,2),(3,2,3),(4,0,2),(5,0,3),(6,1,3)
]

def kruskal(E):
E = sorted(E)
res = []
for ew, v1, v2 in E:
pv1, pv2 = find(v1), find(v2)
if pv1==pv2:continue
union(pv1,pv2)
res.append((ew,v1,v2))
return res

print(kruskal(E))

보다시피 간선을 정렬하고 union-find를 사용해 MST 구성 간선을 추리고 있다. 노드 수가 N이고 간선 수가 E면 시간복잡도는 union-find가 O(log N)인데, 상술했듯이 일반적인 희소 그래프에서는 사실상 상수 시간이고, 무엇보다 정렬하는데 가장 큰 시간이 드는데 흔히 쓰이는 가장 효율적인 정렬 방식의 시간 복잡도를 따르면 Krskal 알고리즘의 시간복잡도는 O(E log E)라고 할 수 있다.

Prim’s algorithm

Prim 알고리즘도 마찬가지로 최소 신장 트리를 찾는 알고리즘이다. 다만 Kruskal 알고리즘과는 달리, Dijkstra 알고리즘과 흡사하게 작동한다.

먼저 임의의 노드에서 시작해서 연결되는 모든 간선들을 가중치 기준 최소 힙에 넣는다. 그리고 힙에서 간선 하나를 꺼내 현재 구성된 트리와 가장 가까운 노드를 트리에 포함시키면서 해당 노드의 간선을 탐색하며 힙에 넣고, 힙에서 또 하나를 꺼내 이런 과정을 반복한다. 그 와중에 간선이 이미 방문한 노드로 향하면 건너 뛴다. 이런 식으로 간선들을 선택해 MST를 추려낼 수 있다.

from collections import defaultdict
from heapq import heappush, heappop

# 예시 그래프 표현
E = [(5, 0, 3), (1, 0, 1), (2, 1, 2), (3, 2, 3), (4, 0, 2), (6, 1, 3)]
G=defaultdict(dict)
for ew, es, ed in E:
G[es][ed]=ew
G[ed][es]=ew

def prim(G):
res=[] # 결과 리스트
H=[] # 최소 힙
for ed in G[0]: # 최소 힙에 첫 노드 간선들 넣기
heappush(H,(G[0][ed],0,ed))
V=set([0]) # MST 포함 여부
cnt=1 # 트리에 포함된 노드 수
while cnt!=len(G): # 모든 노드가 포함될 때 까지
ew, es, ed = heappop(H) # 가중치, 출발지, 도착지
if ed in V:continue # 이미 MST에 포함됐으면 건너뜀
res.append((ew,es,ed))
cnt +=1
V.add(ed) # 트리에 포함
for ed2 in G[ed]: # 새 노드의 모든 간선들 힙에 추가
heappush(H,(G[ed][ed2],ed,ed2))
return res

print(prim(G))

간선 수를 E, 노드 수를 V라 했을 때, 시간 복잡도는 Dijkstra 알고리즘과 같은 과정을 거쳐 O(E log V)임을 알 수 있다.

--

--