def find_parent(ind):
    if make_set[ind] == ind:
        return ind
    else:
        make_set[ind] = find_parent(make_set[ind])
        return make_set[ind]

def union(x,y):
    X = find_parent(x)
    Y = find_parent(y)
    if X == Y:
        return 0

    else:
        if child_cnt[X]< child_cnt[Y]:
            make_set[X] = Y
            return_value = child_cnt[X] * child_cnt[Y]
            child_cnt[Y] += child_cnt[X]
        else:
            make_set[Y] = X
            return_value = child_cnt[X] * child_cnt[Y]
            child_cnt[X] += child_cnt[Y]

        return return_value
            
        

N,M,Q = map(int,input().split())


graph = [{} for i in range(N+1)]

make_set = [i for i in range(N+1)]
child_cnt = [1 for i in range(N+1)]
connect_input = [[],]
for _ in range(M):
    x,y = map(int,input().split())
    graph[x][y] = 1
    graph[y][x] = 1
    connect_input.append((x,y))

result = 0
disconnet_list = []

for _ in range(Q):
    ind = int(input())
    x,y = connect_input[ind]
    graph[x][y] = 0
    graph[y][x] = 0
    disconnet_list.append(ind)


for i in range(1,M+1):
    if i not in disconnet_list:
        x,y = connect_input[i]
        union(x,y)


while disconnet_list:
    ind = disconnet_list.pop()
    x,y = connect_input[ind]
    result += union(x,y)
print(result)

 

이 문제는 역으로 생각하는 편이 제대로 푸는 방식이었다 윗 코드는 첫 풀이이다.

 

문제를 푸는 방식은 다음과 같다. 문제에서 주어진 끊어진 조건들을 전부 끊어졌다고 가정을 하고

 

끝에서부터 하나씩 연결을 해가면서 반대로 추적을 하는것이다.

 

그러므로, 먼저 전체 연결리스트를 들어온 순서에 맞게 index에 알맞게 저장을 해준다.

 

그리고 난뒤에 끊어진 목록들을 따로 저장을 해두고, 끊어지지 않은 연결들끼리 서로 union find를 해준다.

 

 

union_find를 하면서, 각자의 노드 개수를 저장해주는 리스트를 따로 만들어두고,

 

서로 다른 집합이 합쳐질때 그 두 개의 노드의 개수를 합쳐주는 로직을 해주면 된다.

 

이렇게 연결이 된 간선들로 union find를 1차적으로 진행을 한다.

 

 

그리고 끝에서부터 끊어진 간선들끼리 연결을 해준다. 그러면 그 간선을 끊기전 모습으로 돌아갈 수 있다.

 

이렇게 하면서 서로의 집합이 같으면 0을 반환하고, 그게 아니면 서로의 집합의 개수를 곱한 수를 반환하도록 했다.

 

import sys

input = sys.stdin.readline


def find_parent(ind):
    if make_set[ind] == ind:
        return ind
    else:
        make_set[ind] = find_parent(make_set[ind])
        return make_set[ind]

def union(x,y):
    X = find_parent(x)
    Y = find_parent(y)
    if X == Y:
        return 0

    else:
        if rank[X] < rank[Y]:
            X,Y = Y,X

        size_a,size_b = rank[X],rank[Y]
        rank[X] += rank[Y]
        make_set[Y] = X
        return size_a*size_b
            
        

N,M,Q = map(int,input().split())



make_set = [i for i in range(N+1)]
rank = [1 for _ in range(N+1)]
connect_input = [[],]
check = [True]*(M+1)
for _ in range(M):
    x,y = map(int,input().split())
    connect_input.append((x,y))

result = 0
disconnet_list = []

for _ in range(Q):
    ind = int(input())
    disconnet_list.append(ind)
    check[ind] = False


for i in range(1,M+1):
    if check[i]:
        x,y = connect_input[i]
        union(x,y)


while disconnet_list:
    ind = disconnet_list.pop()
    x,y = connect_input[ind]
    result += union(x,y)
print(result)

 

좀 더 깔끔하고 빠른 풀이 방식이다. 첫 풀이 코드같은경우엔 느렸는데

 

그 이유는 간선이 연결됬는지 안됬는지를 구분하는데 not in 으로 했기 때문에, O(N)의 시간이 걸려서 느려진 문제가 있었다.

 

 

그래서 간선들이 연결이 됬는지 안됬는지를 구분하는 리스트를 만들어두고 바로 확인이 가능하도록 했다.

 

 

https://www.secmem.org/blog/2021/03/21/Disjoint-Set-Union-find/

 

Disjoint Set & Union-find

Disjoint Set Disjoint Set(서로소 집합, 분리 집합)이란 서로 공통된 원소를 가지고 있지 않은 두 개 이상의 집합을 말합니다. 모든 집합들 사이에 공통된 원소가 존재하지 않는다는 것을, 각 원소들은

www.secmem.org

 

설명은 잘 못하므로 위의 링크에 있는 Union_find를 보면 알겠지만, rank compression을 활용해서 시간을 좀 더 줄였다.

 

 

Union find는 크루스칼알고리즘에서 처음으로 알게 되었는데, 크루스칼에서만 쓰이는줄 알았는데,

 

생각외로 단독으로 쓰이는 곳이 많았다. 이걸 짜는데 어색해서 크루스칼 알고리즘을 잘 안쓰는데,

 

좀 더 숙달되도록 노력해야겠다.

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 2239 스도쿠  (0) 2021.05.19
[BOJ/백준] 1874 스택 수열  (0) 2021.05.19
[BOJ/백준] 1103 게임  (0) 2021.05.18
[BOJ/백준] 5875 오타  (0) 2021.05.17
[BOJ/백준] 9944 NxM 보드 완주하기  (0) 2021.05.17

+ Recent posts