def find_parent(ind):
if make_set[ind] == ind:
return ind
else:
make_set[ind] = find_parent(make_set[ind])
return make_set[ind]
def union(x,y):
X = find_parent(x)
Y = find_parent(y)
if X == Y:
return 0
else:
if child_cnt[X]< child_cnt[Y]:
make_set[X] = Y
return_value = child_cnt[X] * child_cnt[Y]
child_cnt[Y] += child_cnt[X]
else:
make_set[Y] = X
return_value = child_cnt[X] * child_cnt[Y]
child_cnt[X] += child_cnt[Y]
return return_value
N,M,Q = map(int,input().split())
graph = [{} for i in range(N+1)]
make_set = [i for i in range(N+1)]
child_cnt = [1 for i in range(N+1)]
connect_input = [[],]
for _ in range(M):
x,y = map(int,input().split())
graph[x][y] = 1
graph[y][x] = 1
connect_input.append((x,y))
result = 0
disconnet_list = []
for _ in range(Q):
ind = int(input())
x,y = connect_input[ind]
graph[x][y] = 0
graph[y][x] = 0
disconnet_list.append(ind)
for i in range(1,M+1):
if i not in disconnet_list:
x,y = connect_input[i]
union(x,y)
while disconnet_list:
ind = disconnet_list.pop()
x,y = connect_input[ind]
result += union(x,y)
print(result)
이 문제는 역으로 생각하는 편이 제대로 푸는 방식이었다 윗 코드는 첫 풀이이다.
문제를 푸는 방식은 다음과 같다. 문제에서 주어진 끊어진 조건들을 전부 끊어졌다고 가정을 하고
끝에서부터 하나씩 연결을 해가면서 반대로 추적을 하는것이다.
그러므로, 먼저 전체 연결리스트를 들어온 순서에 맞게 index에 알맞게 저장을 해준다.
그리고 난뒤에 끊어진 목록들을 따로 저장을 해두고, 끊어지지 않은 연결들끼리 서로 union find를 해준다.
union_find를 하면서, 각자의 노드 개수를 저장해주는 리스트를 따로 만들어두고,
서로 다른 집합이 합쳐질때 그 두 개의 노드의 개수를 합쳐주는 로직을 해주면 된다.
이렇게 연결이 된 간선들로 union find를 1차적으로 진행을 한다.
그리고 끝에서부터 끊어진 간선들끼리 연결을 해준다. 그러면 그 간선을 끊기전 모습으로 돌아갈 수 있다.
이렇게 하면서 서로의 집합이 같으면 0을 반환하고, 그게 아니면 서로의 집합의 개수를 곱한 수를 반환하도록 했다.
import sys
input = sys.stdin.readline
def find_parent(ind):
if make_set[ind] == ind:
return ind
else:
make_set[ind] = find_parent(make_set[ind])
return make_set[ind]
def union(x,y):
X = find_parent(x)
Y = find_parent(y)
if X == Y:
return 0
else:
if rank[X] < rank[Y]:
X,Y = Y,X
size_a,size_b = rank[X],rank[Y]
rank[X] += rank[Y]
make_set[Y] = X
return size_a*size_b
N,M,Q = map(int,input().split())
make_set = [i for i in range(N+1)]
rank = [1 for _ in range(N+1)]
connect_input = [[],]
check = [True]*(M+1)
for _ in range(M):
x,y = map(int,input().split())
connect_input.append((x,y))
result = 0
disconnet_list = []
for _ in range(Q):
ind = int(input())
disconnet_list.append(ind)
check[ind] = False
for i in range(1,M+1):
if check[i]:
x,y = connect_input[i]
union(x,y)
while disconnet_list:
ind = disconnet_list.pop()
x,y = connect_input[ind]
result += union(x,y)
print(result)
좀 더 깔끔하고 빠른 풀이 방식이다. 첫 풀이 코드같은경우엔 느렸는데
그 이유는 간선이 연결됬는지 안됬는지를 구분하는데 not in 으로 했기 때문에, O(N)의 시간이 걸려서 느려진 문제가 있었다.
그래서 간선들이 연결이 됬는지 안됬는지를 구분하는 리스트를 만들어두고 바로 확인이 가능하도록 했다.
https://www.secmem.org/blog/2021/03/21/Disjoint-Set-Union-find/
설명은 잘 못하므로 위의 링크에 있는 Union_find를 보면 알겠지만, rank compression을 활용해서 시간을 좀 더 줄였다.
Union find는 크루스칼알고리즘에서 처음으로 알게 되었는데, 크루스칼에서만 쓰이는줄 알았는데,
생각외로 단독으로 쓰이는 곳이 많았다. 이걸 짜는데 어색해서 크루스칼 알고리즘을 잘 안쓰는데,
좀 더 숙달되도록 노력해야겠다.
'알고리즘 > 백준' 카테고리의 다른 글
[BOJ/백준] 2239 스도쿠 (0) | 2021.05.19 |
---|---|
[BOJ/백준] 1874 스택 수열 (0) | 2021.05.19 |
[BOJ/백준] 1103 게임 (0) | 2021.05.18 |
[BOJ/백준] 5875 오타 (0) | 2021.05.17 |
[BOJ/백준] 9944 NxM 보드 완주하기 (0) | 2021.05.17 |