알고리즘/백준
[BOJ/백준] 1414 불우이웃 돕기
mmmlee
2021. 6. 22. 21:10
import sys
import heapq
def convert_Numner(x):
global total_sum,INF
if x.islower():
num = ord(x) - ord('a') + 1
total_sum += num
elif x.isupper():
num = ord(x) - ord('A')+ 27
total_sum += num
else:
num = INF
total_sum += 0
return num
def input():
return sys.stdin.readline().rstrip()
N = int(input())
total_sum = 0
INF = float('inf')
arr = [list(map(convert_Numner,list(input()))) for _ in range(N)]
cnt = 0
mst_dis = 0
distance_list = [INF for _ in range(N)]
visited = [False for _ in range(N)]
node_list = []
heapq.heappush(node_list,(0,0))
distance_list[0] = 0
while node_list:
cur_dis,cur_node = heapq.heappop(node_list)
if visited[cur_node]:continue
cnt += 1
mst_dis += cur_dis
visited[cur_node] = True
for next_node in range(N):
if visited[next_node]:continue
min_value = min(arr[cur_node][next_node],arr[next_node][cur_node])
if distance_list[next_node] > min_value:
distance_list[next_node] = min_value
heapq.heappush(node_list,(min_value,next_node))
if cnt == N:
print(total_sum-mst_dis)
else:
print(-1)
처음에 프림알고리즘으로 풀었다. MST에서 익숙한 알고리즘이 프림알고리즘이라, 우선적으로 푼 방식이다.
여기서 여러번 틀렸었는데 그 이유는 다음과 같다. 우리가 연결을 하는 통로는 무방향이다.
즉 1번방에서 2번방으로 연결하는 랜선과 2번방에서 1번방으로 연결하는 랜선은 둘중 하나만 선택을 하면, 둘이 연결이 된것으로 본다.
이 문제를 풀때 무조건 0번방부터 시작을 했기때문에
3
zzz
abc
abc
와 같은 입력이 들어왔을때, 오히려 더 짧은 반대편방에서 오는 경우를 체크하지 못했다.
이러한 문제점을 해결하기 위해, 거리를 비교할때, 양방향의 최소값으로 비교를 하고, 그값을 넣어주는 방식으로 했다.
이런식이 아니라, 처음부터 graph라는 딕셔너리를 만들어서 최소값을 저장해주어도 된다.
이 점만 주의하면, 나머지는 일반적인 프림알고리즘을 이용한 MST 문제이다.
cnt개수가 N보다 부족할시에는 모든 노드들이 연결이 된것이 아니므로, -1을 출력해주면 된다.
import sys
def convert_Numner(x):
global total_sum,INF
if x.islower():
num = ord(x) - ord('a') + 1
total_sum += num
elif x.isupper():
num = ord(x) - ord('A')+ 27
total_sum += num
else:
num = INF
total_sum += 0
return num
def find_parents(X):
if X == make_set[X]:
return X
make_set[X] = find_parents(make_set[X])
return make_set[X]
def union(x,y):
X = find_parents(x)
Y = find_parents(y)
if X !=Y:
if ranks[X]< ranks[Y]:
X,Y = Y,X
make_set[Y] = X
if ranks[X] == ranks[Y]:
ranks[X] += 1
return True
return False
def input():
return sys.stdin.readline().rstrip()
N = int(input())
total_sum = 0
INF = float('inf')
edge_list = []
for x in range(N):
input_string = input()
temp = []
for y in range(N):
num = convert_Numner(input_string[y])
if x == y:continue
if num == INF:continue
edge_list.append((num,x,y))
edge_list.sort(reverse=True)
cnt = 1
result = 0
make_set = [i for i in range(N)]
ranks = [1 for _ in range(N)]
while cnt <N and edge_list:
pay,node_a,node_b = edge_list.pop()
if union(node_a,node_b):
cnt += 1
result += pay
if cnt == N:
print(total_sum-result)
else:
print(-1)
크루스칼을 이용한 풀이이다. 푸는 방법은 위와 동일하다 대신 edge_list에 x와 y가 같은경우와 연결이 되지않는 경우는 제외시켰다.