import sys
def input():
return sys.stdin.readline().rstrip()
def find_parents(X):
if X == make_set[X]:
return X
make_set[X] = find_parents(make_set[X])
return make_set[X]
def union(x,y):
X = find_parents(x)
Y = find_parents(y)
if X !=Y:
if ranks[X]< ranks[Y]:
X,Y = Y,X
make_set[Y] = X
groups[X].extend(groups[Y])
if ranks[X] == ranks[Y]:
ranks[X] += 1
return True
return False
N,M = map(int,input().split())
edge_list = []
total_pay = 0
for _ in range(M):
x,y,pay = map(int,input().split())
if x>y:
x,y = y,x
edge_list.append((x,y,pay))
total_pay += pay
K = 10**9
past_pay = 0
answer = 0
make_set = [i for i in range(N+1)]
ranks = [1 for _ in range(N+1)]
groups = [[k] for k in range(N+1)]
edge_list.sort(key=lambda x : x[2])
while edge_list:
x,y,pay = edge_list.pop()
p_x,p_y = find_parents(x),find_parents(y)
if p_x!= p_y:
answer = answer + len(groups[p_x])*len(groups[p_y])*(total_pay-past_pay)
answer = answer%K
union(x,y)
past_pay += pay
print(answer)
구사과님의 분리집합 추천 문제에 있어서, 풀었던 문제였는데 푸는데 어려웠던 문제였다.
이 문제는 분리집합만을 이요해 푸는 문제인데, 푸는 아이디어는 다음과 같다.
모든 간선이 연결이 끊어져 있다고 생각을 하고, 간선의 비용이 가장 많은 비용 순으로 연결을 해주는 것이다.
문제에 주어진 Cost(u,v)는 u,v가 연결이 안될때까지 최소비용의 간선을 제거해주는 것이다.
즉 역으로 생각하면 간선의 비용이 가장 많은 비용 순으로 연결을 해줄때, 처음 u,v가 같은 집합이 되면,
그 전까지 연결된 비용을 제외한 나머지 비용들이, 이 u,v의 Cost가 될 수 있다.
즉 전체 간선의 연결비용을 구한 상태에서 비용이 높은 순으로 연결을 해주면서 연결이 되었을때 까지의 누적 비용을
뺀 값이 해당 cost(u,v)인 것을 알 수 있게 된다.
그러면 여기서 고민사항은 다음과 같다.
서로 단독집합일때에는 위와 같이 한다고 할때, 만약에 서로 다른 집합에 이미 속해있는 상황에서는 어떻게 해야할지,
일 것이다.
이 문제는 문제에서 cost(u,v)에서 u<v 이므로, 우리는 한가지 정보를 더 저장을 시켜줘야한다.
그것은 부모에 해당 집합에 속해있는 노드를 저장시켜놓는것이다.
그래서 우리는 각 부모들에 노드를 저장시켜주고, 그 노드의 수를 서로 곱해주면 두 집합에서 나올 수 있는
u,v의 집합의 개수와 동일 하다.
정리하면
(total_pay - past_pay) *len(group[parent_u]) *len(group[parent_v]) 로 나타낼수 있다.
위의 코드를 통해 AC를 받을 수 있었다. 이 문제의 아이디어를 떠오르는데 힘들었고, 구현을 하는데 어려움이 많았다.
아직 분리집합에 대해 잘 모르고, 응용하는법에 대해 잘 몰랐다는 것을 알게 되었다.
import sys
def input():
return sys.stdin.readline().rstrip()
def find_parents(X):
if X == make_set[X]:
return X
make_set[X] = find_parents(make_set[X])
return make_set[X]
def union(x,y):
X = find_parents(x)
Y = find_parents(y)
if X !=Y:
if ranks[X]< ranks[Y]:
X,Y = Y,X
make_set[Y] = X
groups[X] +=groups[Y]
if ranks[X] == ranks[Y]:
ranks[X] += 1
return True
return False
N,M = map(int,input().split())
edge_list = []
total_pay = 0
for _ in range(M):
x,y,pay = map(int,input().split())
if x>y:
x,y = y,x
edge_list.append((x,y,pay))
total_pay += pay
K = 10**9
past_pay = 0
answer = 0
make_set = [i for i in range(N+1)]
ranks = [1 for _ in range(N+1)]
groups = [ 1 for _ in range(N+1)]
edge_list.sort(key=lambda x : x[2])
while edge_list:
x,y,pay = edge_list.pop()
p_x,p_y = find_parents(x),find_parents(y)
if p_x!= p_y:
answer = answer + groups[p_x]*groups[p_y]*(total_pay-past_pay)
answer = answer%K
union(p_x,p_y)
past_pay += pay
print(answer)
이 코드는 노드를 전부 저장하는게 아닌, 노드의 개수를 저장해주는 방식으로 바꾼 코드이다.
'알고리즘 > 백준' 카테고리의 다른 글
[BOJ/백준] 15661 링크와 스타트 (0) | 2021.06.29 |
---|---|
[BOJ/백준] 10711 모래성 (0) | 2021.06.29 |
[BOJ/백준] 1398 동전 (0) | 2021.06.29 |
[BOJ/백준] 1045 도로 (0) | 2021.06.29 |
[BOJ/백준] 2250 트리의 높이와 너비 (0) | 2021.06.22 |