import sys
input = sys.stdin.readline
def dfs(node,flag,*args):
stack = [(node,0)]
distance_list = [0 for _ in range(N+1)]
visited = [False for _ in range(N+1)]
visited[node] = True
if not flag:
visited[args[0]] = True
while stack:
node,distance = stack.pop()
distance_list[node] = distance
for next_node in graph[node]:
if not visited[next_node]:
visited[next_node] = True
stack.append((next_node,distance+graph[node][next_node]))
temp = []
for i in range(1,N+1):
temp.append((distance_list[i],i))
temp.sort(key=lambda x :-x[0])
if flag:
return_value = temp[0][1]
else:
return_value = temp[0][0]
return return_value
N = int(input())
graph = [{} for _ in range(N+1)]
for _ in range(N-1):
x,y,pay = map(int,input().split())
graph[x][y] = pay
graph[y][x] = pay
first_point = dfs(1,True)
second_point = dfs(first_point,True)
first_value = dfs(first_point,False,second_point)
second_value = dfs(second_point,False,first_point)
print(max(first_value,second_value))

 이 문제는 트리의 지름을 구하는 방법을 응요한 문제이다.

 

총 4번의 dfs를 통해 두번째 트리의 지름을 알 수 있다.

 

첫번째 dfs를 통해 아무지점에서 가장 먼 지점을 구한다.

 

이 지점은 지름을 구성하는 한 점이 될것이다.

 

 

두번째 dfs를 통해 첫번째 dfs에서 구한 점에서 가장 먼 점을 구한다.

 

그러면 이 점은 트리의 지름을 구성하는 가장 먼 점이 될것이다.

 

 

이렇게 2번의 dfs를 통해 우리는 가장 먼 점 2개를 구했다.

 

그러면 이 2점을 각각 dfs를 돌려, 가장 먼 점에서 2번째인 것을 찾아낸다.

 

그리고 이렇게 두 거리를 구한 뒤 그 중에 더 큰 점을 선택하면 되는 문제이다.

 

 

import sys
input = sys.stdin.readline
def dfs(node):
distance_list = [0 for _ in range(N+1)]
visited = [False for _ in range(N+1)]
visited[node] = True
stack = [(node,0)]
while stack:
node,distance = stack.pop()
distance_list[node] = distance
for next_node in graph[node]:
if not visited[next_node]:
visited[next_node] = True
stack.append((next_node,graph[node][next_node]+distance))
return distance_list
N = int(input())
graph = [{} for _ in range(N+1)]
for _ in range(N-1):
x,y,pay = map(int,input().split())
graph[x][y] = pay
graph[y][x] = pay
distance1 = dfs(1)
far_point1 = distance1.index(max(distance1))
distance2 = dfs(far_point1)
far_point2 = distance2.index(max(distance2))
distance3 = dfs(far_point2)
result = sorted(distance2+distance3)[-3]
print(result)

단 3번의 dfs를 통해 구하는 방법도 있다.

 

이 방법은 지름을 구성하는 2점의 전체길이를 한 리스트로 정하고 뒤에서 3번째를 출력해주는 것이다.

 

왜냐하면 한 리스트당 한 지점에서 가장 먼 지점은 서로 자신들이기 때문에, 그 2점을 제외하고 그 다음번으로 큰 것이

 

두번째 트리의 지름의 답이된다.

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 20924 트리의 기둥과 가지  (0) 2021.06.07
[BOJ/백준] 20300 서강근육맨  (0) 2021.06.07
[BOJ/백준] 15644 구슬 탈출 3  (0) 2021.06.07
[BOJ/백준] 3056 007  (0) 2021.06.07
[BOJ/백준] 16188 달빛 여우  (0) 2021.06.07
import math
import sys
def dfs(node):
stack = [(node,0,0)]
visited = [True]*(N+1)
distance = 0
max_node = -1
min_time = float('inf')
visited[node] = False
while stack:
node,dis,time = stack.pop()
if dis > distance:
distance = dis
max_node = node
min_time = time
elif dis == distance and min_time > time:
max_node = node
min_time = time
for next_node in graph[node]:
if visited[next_node]:
new_dis = dis + 1
new_time = time + graph[node][next_node]
visited[next_node] = False
stack.append((next_node,new_dis,new_time))
return [max_node,distance,min_time]
input = sys.stdin.readline
N,T = map(int,input().split())
graph = [{} for _ in range(N+1)]
for _ in range(N-1):
x,y,time = map(int,input().split())
graph[x][y] = time
graph[y][x] = time
far_node_info = dfs(1)
result = dfs(far_node_info[0])
print(math.ceil(result[2]/T))

트리의 지름 이 문제를 풀어보면 해당 문제를 좀 더 쉽게 풀 수 있다.

문제의 조건은 총 2가지이다. 문제를 가장 많이 풀어야하며, 그 푸는데 시간이 가장 짧아야한다.

이 문제는 트리구조이다. 그래서 문제를 가장 많이 푼다는 것은 트리의 지름을 구하는것과 같다.

그러므로 처음 dfs로 아무 노드에서 가장 먼 노드를 찾고, 그 노드에서 가장 먼 노드들을 찾으면 그게 트리의 지름이 된다.

이걸 응용해서, 첫 dfs로 가장 먼 노드 하나를 찾는다.

그리고 두번째 dfs로 찾은 가장 먼 노드를 기준으로, dfs를 돌려서 깊이가 가장 깊은 노드들을 찾는다. 그리고 그 중에서, 시간이 가장 짧은 것을 선택해주면 된다.

이 문제는 1967번 문제의 트리의 지름을 응용한 문제이고, 트리의 특성에 대해 알고 있으면 쉽게 풀 수 있었던 문제였지만,

처음에 트리의 지름에 대한 아이디어를 얻지 못해서 어려웠던 문제이다.

import sys
input = sys.stdin.readline
def check(num):
visited[num] = True
stack = [num]
while stack:
node = stack.pop(0)
for next_node in tree[node]:
if tree[node][next_node] == 1:
if not visited[next_node]:
visited[next_node] = True
stack.append(next_node)
tree[node][next_node] = 0
tree[next_node][node] = 0
else:
return False
return True
tc = 1
while True:
N,M = map(int,input().split())
if N+M == 0:
break
parent_cnt = [0]*(N+1)
tree = [{} for _ in range(N+1)]
for _ in range(M):
x,y = map(int,input().split())
tree[x][y] = 1
tree[y][x] = 1
cnt = 0
visited = [False]*(N+1)
for num in range(1,N+1):
if not visited[num]:
if check(num):
cnt += 1
if cnt == 0:
print(f'Case {tc}: No trees.')
elif cnt == 1:
print(f'Case {tc}: There is one tree.')
else:
print(f'Case {tc}: A forest of {cnt} trees.')
tc += 1
import sys
input = sys.stdin.readline
N = int(input())
tree = [[-1 for _ in range(2)] for _ in range(N+1)]
for i in range(1,N+1):
left_ndoe,right_node = map(int,input().split())
tree[i][0] = left_ndoe
tree[i][1] = right_node
K = int(input())
cu_node = 1
while K >=0:
left_or_right = K%2
if tree[cu_node][0] != -1 and tree[cu_node][1] != -1:
if left_or_right:
cu_node = tree[cu_node][0]
else:
cu_node = tree[cu_node][1]
K = K//2 + left_or_right
else:
if tree[cu_node][0] == -1 and tree[cu_node][1] == -1:
break
elif tree[cu_node][1] == -1:
cu_node = tree[cu_node][0]
else:
cu_node = tree[cu_node][1]
print(cu_node)

 

import sys
sys.setrecursionlimit(100001)
def dfs(node):
visited[node] = True
child_node = []
for next_node in graph[node]:
if not visited[next_node]:
child_node.append(next_node)
if not len(child_node):
dp[node][0] = 1
return
else:
for child in child_node:
dfs(child)
for child in child_node:
if dp[child][0]:
dp[node][1] = 1
break
else:
dp[node][0] = 1
N = int(input())
graph = [[] for _ in range(N+1)]
for _ in range(N-1):
x,y = map(int,input().split())
graph[x].append(y)
graph[y].append(x)
visited = [False]*(N+1)
dp = [[0]*2 for _ in range(N+1)]
dfs(1)
answer = min(list(map(sum,list(zip(*dp)))))
print(answer)
import sys
input = sys.stdin.readline
T = int(input())
def find_parents(node):
parent_list = [node]
while parent_num[node] != -1:
parent = parent_num[node]
parent_list.append(parent)
node = parent
return parent_list
for _ in range(T):
N = int(input())
parent_num = [-1]*(N+1) # 해당 index의 부모가 안에 들어가 있다.
for _ in range(N-1):
parent,child = map(int,input().split())
parent_num[child] = parent
num1,num2 = map(int,input().split())
parents1 = find_parents(num1)
parents2 = find_parents(num2)
if len(parents1) < len(parents2):
parents1,parents2 = parents2, parents1
result = -1
for num in parents1:
if num in parents2:
result = num
break
print(result)
import sys
sys.setrecursionlimit(100000)
def dfs(node):
if visited[node]:
return
visited[node] = True
child_nodes = []
for next_node in graph[node]:
if not visited[next_node]:
child_nodes.append(next_node)
if not len(child_nodes):
dp[node][0] = town_person[node]
return
for child_node in child_nodes:
dfs(child_node)
dp[node][0] += dp[child_node][1]
dp[node][1] += max(dp[child_node][0],dp[child_node][1])
dp[node][0] += town_person[node]
N = int(input())
town_person = [0] +list(map(int,input().split()))
graph = [[] for _ in range(N+1)]
for _ in range(N-1):
x,y = map(int,input().split())
graph[x].append(y)
graph[y].append(x)
visited = [False]*(N+1)
dp = [[0]*2 for _ in range(N+1)]
# 0번 인덱스 : 참여를 했을때
# 1번 인덱스 : 참여를 안 했을때
dfs(1)
print(max(dp[1]))
N = int(input())
arr = list(map(int,input().split()))
parent = {i:[] for i in range(N)}
child = {i:[] for i in range(N)}
root_node = -1
for ind in range(N):
if arr[ind] == -1:
root_node = ind
else:
parent[ind].append(arr[ind])
child[arr[ind]].append(ind)
remove_node = int(input())
if root_node != remove_node:
remove_nodes = set()
stack = [remove_node]
visited = [True] * N
while stack:
node = stack.pop(0)
remove_nodes.add(node)
for k in child[node]:
if visited[k]:
visited[k] = False
stack.append(k)
parent_node = parent[node][0]
child[parent_node].remove(node)
leef_nodes = set()
for ind in range(N):
if not len(child[ind]):
leef_nodes.add(ind)
print(len(leef_nodes-remove_nodes))
else:
print(0)

가장 자신없고, 공부를 많이 안한 Tree 분야라, 코드가 난잡하다. 기본적인 원리는 parent_node들과 child_node들을 정리해주고, root_node를 구해준다.

만약 지워야할 remove_node와 root_node가 같을시 leafnode가 없으니 0을 출력해준다.

 

그 외에는 다음과 같이 진행을 한다. 지워진 node들과 leafnode가 될수 있는 node들을 구한뒤 leafnode에서 remove_node를 차집합을 해줘서 그 길이를 출력해준다. 이게 가능했던 이유는 , 반복문을 돌면서 remove_node로 들어간 node들의 child_node를 빼줬기 때문이다.

 

내가 푼 풀이는 난잡하게 여러 변수들이 존재하고, 깔끔하지 못해서 잘 푸신 다른분의 코드를 클론코딩하면서 공부도 했다.

 

 

---- 클론 코딩 및 분석---

# nh7881님 코드 분석
def find_leafs(index,child_nodes):
# index가 -1이라는 것은 root_node가 remove_node인 경우이니, 그때에는 0을 반환을 해준다.
if index == -1:
return 0
# child_node의 길이가 0인것은 child_Node가 없는 것이므로, leaf_node이다 그러므로 1을 반환해준다.
if len(child_nodes[index]) == 0:
return 1
result = 0
# 현재 노드인 index의 child_node를 가져와서, 재귀를 실행시켜준다.
for child_node in child_nodes[index]:
result += find_leafs(child_node,child_nodes)
return result
N = int(input())
graph = list(map(int,input().split()))
# 최상위 노드를 찾아주기 위함이다. 초기값은 node에 존재하지 않는 값으로 해준다.
root_node = -1
remove_node = int(input())
child_nodes = {i:[] for i in range(N)}
for ind in range(N):
# 우리는 leaf_node를 찾을것이고, 해당 index에 부모 node로 들어온
# input값을 반대로 바꿔주는 과정이 필요하다.
# 만약에 지우는 node와 index가 같으면 굳이 parent을 찾아 child를 넣어주는 과정이 필요없다.
# 그래서 continue로 넘어가준다.
# 또한 유일하게 부모가 없는 root_node는 따로 구분을 해준다. if문의 순서가 remove_node가 먼저 앞으로
# 오는 이유는 remove_node가 root_node일수 있기 때문이다. 이럴때를 구분해주기 위해, remove_node인지 판별하는게 먼저온다.
# 그외에는 전부 parent_node를 기준으로 child를 추가해주는 방식으로 해준다.
if remove_node == ind:
continue
if graph[ind] == -1:
root_node = ind
continue
child_nodes[graph[ind]].append(ind)
# root_node를 기점으로 leaf_node를 찾는 재귀함수를 실행시켜준다.
print(find_leafs(root_node,child_nodes))

 

 

+ Recent posts