CS/알고리즘_개념

최소신장트리(Spanning Tree) : 크루스칼 알고리즘

Jedy_Kim 2021. 6. 7. 18:58
728x90

최소 신장 트리의 이해

1. 신장 트리 란?

  • Spanning Tree, 또는 신장 트리 라고 불리움 (Spanning Tree가 보다 자연스러워 보임)
  • 원래의 그래프의 모든 노드가 연결되어 있으면서 트리의 속성을 만족하는 그래프
  • 신장 트리의 조건
    • 본래의 그래프의 모든 노드를 포함해야 함
    • 모든 노드가 서로 연결
    • 트리의 속성을 만족시킴 (사이클이 존재하지 않음)

2. 최소 신장 트리

  • Minimum Spanning Tree, MST 라고 불리움
  • 가능한 Spanning Tree 중에서, 간선의 가중치 합이 최소인 Spanning Tree를 지칭함

3. 최소 신장 트리 알고리즘

  • 그래프에서 최소 신장 트리를 찾을 수 있는 알고리즘이 존재함
  • 대표적인 최소 신장 트리 알고리즘
    • Kruskal’s algorithm (크루스칼 알고리즘), Prim's algorithm (프림 알고리즘)

4. 크루스칼 알고리즘 (Kruskal's algorithm)

  1. 모든 정점을 독립적인 집합으로 만든다.
  2. 모든 간선을 비용을 기준으로 정렬하고, 비용이 작은 간선부터 양 끝의 두 정점을 비교한다.
  3. 두 정점의 최상위 정점을 확인하고, 서로 다를 경우 두 정점을 연결한다. (최소 신장 트리는 사이클이 없으므로, 사이클이 생기지 않도록 하는 것임)

탐욕 알고리즘을 기초로 하고 있음 (당장 눈 앞의 최소 비용을 선택해서, 결과적으로 최적의 솔루션을 찾음)

5. Union-Find 알고리즘

  • Disjoint Set을 표현할 때 사용하는 알고리즘으로 트리 구조를 활용하는 알고리즘
  • 간단하게, 노드들 중에 연결된 노드를 찾거나, 노드들을 서로 연결할 때 (합칠 때) 사용
  • Disjoint Set이란
    • 서로 중복되지 않는 부분 집합들로 나눠진 원소들에 대한 정보를 저장하고 조작하는 자료구조
    • 공통 원소가 없는 (서로소) 상호 배타적인 부분 집합들로 나눠진 원소들에 대한 자료구조를 의미함
    • Disjoint Set = 서로소 집합 자료구조
  1. 초기화
    • n 개의 원소가 개별 집합으로 이뤄지도록 초기화
  2. Union
    • 두 개별 집합을 하나의 집합으로 합침, 두 트리를 하나의 트리로 만듬
  3. Find
    • 여러 노드가 존재할 때, 두 개의 노드를 선택해서, 현재 두 노드가 서로 같은 그래프에 속하는지 판별하기 위해, 각 그룹의 최상단 원소 (즉, 루트 노드)를 확인

Union-Find 알고리즘의 고려할 점

  • Union 순서에 따라서, 최악의 경우 링크드 리스트와 같은 형태가 될 수 있음.
  • 이 때는 Find/Union 시 계산량이 O(N) 이 될 수 있으므로, 해당 문제를 해결하기 위해, union-by-rank, path compression 기법을 사용함

union-by-rank 기법

    • 각 트리에 대해 높이(rank)를 기억해 두고,
    • Union시 두 트리의 높이(rank)가 다르면, 높이가 작은 트리를 높이가 큰 트리에 붙임 (즉, 높이가 큰 트리의 루트 노드가 합친 집합의 루트 노드가 되게 함)

    • 높이가 h - 1 인 두 개의 트리를 합칠 때는 한 쪽의 트리 높이를 1 증가시켜주고, 다른 쪽의 트리를 해당 트리에 붙여줌

  • 초기화시, 모든 원소는 높이(rank) 가 0 인 개별 집합인 상태에서, 하나씩 원소를 합칠 때, union-by-rank 기법을 사용한다면,
    • 높이가 h 인 트리가 만들어지려면, 높이가 h - 1 인 두 개의 트리가 합쳐져야 함
    • 높이가 h - 1 인 트리를 만들기 위해 최소 n개의 원소가 필요하다면, 높이가 h 인 트리가 만들어지기 위해서는 최소 2n개의 원소가 필요함
    • 따라서 union-by-rank 기법을 사용하면, union/find 연산의 시간복잡도는 O(N) 이 아닌, 𝑂(𝑙𝑜𝑔𝑁)O(logN) 로 낮출 수 있음

path compression

  • Find를 실행한 노드에서 거쳐간 노드를 루트에 다이렉트로 연결하는 기법
  • Find를 실행한 노드는 이후부터는 루트 노드를 한번에 알 수 있음

  • union-by-rank 와 path compression 기법 사용시 시간 복잡도는 다음 계산식을 만족함이 증명되었음
    • 𝑂(𝑀𝑙𝑜𝑔𝑁)O(Mlog∗N)
    • 𝑙𝑜𝑔𝑁log∗N 은 다음 값을 가짐이 증명되었음
      • N이 265536265536 값을 가지더라도, 𝑙𝑜𝑔𝑁log∗N 의 값이 5의 값을 가지므로, 거의 O(1), 즉 상수값에 가깝다고 볼 수 있음

N𝑙𝑜𝑔𝑁log∗N

1 0
2 1
4 2
16 3
65536 4
265536265536 5

 

 

#코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from collections import defaultdict
import sys
 
'''
        (1)     (4)     (1)     (2)
     1 ----- 2 ----- 4 ----- 5 ----- 8
      \     /         \     /       /
   (2) \   / (3)   (2) \   / (2)   / (4)
         3               6  ----- 7        
                              (3)
'''
 
# x 가 속한 트리에서의 루트를 찾는다.
def Find(x):  
  # 내 자신이 루트인지를 판단.
  if x == parent[x]:
    return x
  else:
    y = Find(parent[x])
    parent[x] = y
    return y
 
# 시작점과 끝점이 다른 그룹에 있으면 True, 같은 그룹에 있으면 False
def Union(a, b):
  global parent
   
  rootA = Find(a) # a가 속한 트리의 루트를 반환
  rootB = Find(b) # b가 속한 트리의 루트를 반환
  
  if rootA == rootB:
    return False
  else:
    # 다른 그룹이라면 
    parent[rootA] = rootB # a와 b를 같은 그룹으로 묶는다.
    return True
  
 
if __name__ == "__main__":
  input = sys.stdin.readline
  '''
  설계
  1. 그래프를 입력 받는다
  2. 간선 가중치가 작은 순서대로 정렬을 한다.
  3. 간선 가중치가 작은 간선부터 차례대로 선택하려는 시도를 해본다.
    3-1. 시작점과 끝점이 같은 그룹에 속한다면, 그냥 넘어간다 (사이클 발생)
    3-2. 그게 아니라면, 이 간선을 선택한다.(이 간선의 가중치를 더한다.)
         시작점과 끝점을 같은 그룹으로 만들어 준다.
  ''' 
  n, m  = map(int, input().split())
  edgeList = []
  parent = [0]*(n+1# parent[x] : x의 부모노드의 번호
  for i in range(1, n+1):
    parent[i] = i
  
  result = 0
  
  for _ in range(m):
    p, q, c = map(int, input().split())
    # p : 시작점, q : 끝점, c : 가중치
    edgeList.append((p, q, c)) 
  
  edgeList = sorted(edgeList, key=lambda edge: edge[2])
  
  for i in range(m):
    # edgeList[i]의 간선을 선택하려고 시도할 것임.
    
    if Union(edgeList[i][0], edgeList[i][1]): # 만약 시작점과 끝점이 다른 그룹에 있다면?
      result += edgeList[i][2
      
  print(result)
cs
입력
#첫째 줄 : 정점의 갯수, 간선의 갯수
#둘째 줄 ~ : 시작점, 끝점, 가중치
8 10
1 2 1
1 3 2
2 3 3
2 4 4
4 5 1
4 6 2
5 6 2
5 8 2
6 7 3
7 8 4

출력
15

 

# 복습

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import sys
 
'''
     (1)     (4)     (1)     (2)
  1 ----- 2 ----- 4 ----- 5 ----- 8
   \     /         \     /       /
(2) \   / (3)   (2) \   / (2)   / (4)
     \ /             \ /       /
      3               6 ----- 7  
 
'''
'''
  (1) 그래프를 입력을 받는다.
  (2) 간선 가중치가 작은 순서대로 정렬을 한다.
  (3) 간선 가중치가 작은 간선부터 차례대로 선택하려는 시도를 해본다.
    (3-1) 시작점과 끝점이 같은 그룹에 속한다면, 그냥 넘어간다.
    (3-2) 그게 아니라면, 이 간선을 선택한다. 이 간선의 가중치를 더한다.
          시작점과 끝점을 같은 그룹으로 만들어 준다.
'''
 
# v1, v2가 같은 그룹인지를 판단
def Union(v1, v2):
  global parent 
   
  root_a = Find(v1) # v1이 속한 트리의 루트 값
  root_b = Find(v2) # v2가 속한 트리의 루트 값
  
  if root_a == root_b:# 만약 같은 그룹이라면, False를 반환.
    return False
  else# 그게 아니라면, v1과 v2를 같은 그룹으로 만들고 True를 반환.
    parent[root_a] = root_b
    return True
    
    
# 정점 v가 속한 루트를 반환
def Find(v):
  global parent
  # 자기 자신이 정점인 경우
  if v == parent[v]: return v
  else:
    y = Find(parent[v])
    parent[v] = y
    return y
  
  
if __name__ == '__main__':
  input = sys.stdin.readline
  n, m  = map(int, input().split())
  # 그래프
  my_graph = []
  # parent[x] : x의 부모노드의 번호 -> 그룹의 대표값
  parent   = [ i for i in range(n + 1)] 
  # 답(최솟값)
  result   = 0
  
  # (1) 그래프를 입력을 받는다.
  for _ in range(m):
    a, b, c = map(int, input().split())
    my_graph.append((a, b, c))
    
  # (2) 간선 가중치가 작은 순서대로 정렬을 한다.
  my_graph = sorted(my_graph, key = lambda my_graph : my_graph[2])
  
  # (3) 간선 가중치가 작은 간선부터 차례대로 선택하려는 시도를 해본다.
  for i in range(m):
    if Union(my_graph[i][0], my_graph[i][1]): # 시작점과 끝점이 다른 그룹에 있다면? 
      result += my_graph[i][2]
  
  print(result)
  
 
'''
8 10
1 2 1
1 3 2
2 3 3
2 4 4
4 5 1
4 6 2
5 6 2
5 8 2
6 7 3
7 8 4
 
15
'''
cs

 



반응형