import sys
from collections import defaultdict
import heapq
def input():
    return sys.stdin.readline().rstrip()
N,M = map(int,input().split())



graph = [defaultdict(list) for _ in range(N+1)]
for _ in range(M):
    x,y,pay = map(int,input().split())
    graph[x][y].append((pay,True))
    graph[y][x].append((pay,False))


node_list = []
INF = float('inf')
distance_list = [INF for _ in range(N+1)]
distance_list[1] = 0
heapq.heappush(node_list,(0,1,1,True))
result = []

while node_list:
    cur_dis, cur_node , parent_node , cur_flag = heapq.heappop(node_list)
    if cur_dis > distance_list[cur_node]:
        continue
    if cur_node != parent_node:
        if cur_flag:
            result.append((parent_node, cur_node))
        else:
            result.append((cur_node, parent_node))
    for next_node in graph[cur_node]:


        for pay,flag in graph[cur_node][next_node]:
            if cur_dis + pay < distance_list[next_node]:
                distance_list[next_node] = cur_dis + pay
                heapq.heappush(node_list,(distance_list[next_node], next_node , cur_node , flag))
    
print(len(result))
for row in result:
    print(*row)

처음으로 풀었던 방식이다. 문제의 출력을 입력에 주어진 간선으로 출력해야하는 줄 알고, 좀 복잡하게 풀었다.

 

이 문제를 처음 본 순간 MST 문제 인줄 알았는데, MST 문제는 아니였다.

 

1번 조건만 보고 모든 노드들이 연결되는 최소인줄 알고 MST인줄 알았는데,

 

2번 조건을 찬찬히 읽어보니 다익스트라 문제였다.

 

 

네트워크를 복구해서 통신이 가능하도록 만드는 것도 중요하지만, 해커에게 공격을 받았을 때 보안 패킷을 전송하는 데 걸리는 시간도 중요한 문제가 된다. 따라서 슈퍼컴퓨터가 다른 컴퓨터들과 통신하는데 걸리는 최소 시간이, 원래의 네트워크에서 통신하는데 걸리는 최소 시간보다 커져서는 안 된다.

 

그 2번 조건은 다음과 같다.

 

슈퍼컴퓨터를 기준으로 모든 컴퓨터들은 최소 시간을 보장해야한다.

 

즉 슈퍼컴퓨터를 기준으로 모든 노드들에 대한 최소 시간의 다익스트라를 구해야하는 문제이다.

 

이 점만 주의하면 일반적인 다익스트라 풀이와 동일하다.

 

 

import sys
from collections import defaultdict
import heapq
def input():
    return sys.stdin.readline().rstrip()
N,M = map(int,input().split())



graph = [defaultdict(list) for _ in range(N+1)]

for _ in range(M):
    x,y,pay = map(int,input().split())
    graph[x][y].append(pay)
    graph[y][x].append(pay)


node_list = []
INF = float('inf')
distance_list = [INF for _ in range(N+1)]
distance_list[1] = 0
heapq.heappush(node_list,(0,1,1))
result = []

while node_list:
    cur_dis, cur_node , parent_node  = heapq.heappop(node_list)
    if cur_dis > distance_list[cur_node]:
        continue
    if cur_node != parent_node:
        result.append((parent_node, cur_node))
    for next_node in graph[cur_node]:


        for pay in graph[cur_node][next_node]:
            if cur_dis + pay < distance_list[next_node]:
                distance_list[next_node] = cur_dis + pay
                heapq.heappush(node_list,(distance_list[next_node], next_node , cur_node ))

print(len(result))

for row in result:
    print(*row)

개선한 풀이이다.

+ Recent posts