import sys

sys.setrecursionlimit(20000)

def solution(k, num, links):
    left = sum(num)//k
    right = sum(num)
    N = len(num)
    degrees = [0 for _ in range(N)]
    for ind in range(N):
        if links[ind][0] != -1:
            degrees[links[ind][0]] += 1
        if links[ind][1] != -1:
            degrees[links[ind][1]] += 1
    root = -1
    for ind in range(N):
        if not degrees[ind]:
            root = ind
            break
    INF = float('inf')
    def check(node):
        nonlocal links,INF,dp,mid
        if links[node][0] == -1 and links[node][1] == -1:
            if num[node]<=mid:
                dp[node][0] = num[node]
                dp[node][1] = 1
            else:
                dp[node][0] = 0
                dp[node][1] = INF
        else:
            for ind in range(2):
                if links[node][ind] != -1:
                    check(links[node][ind])
            left_node = links[node][0]
            right_node = links[node][1] 
            if dp[left_node][0] + dp[right_node][0] + num[node] <= mid:
                dp[node][0] = dp[left_node][0] + dp[right_node][0] + num[node]
                dp[node][1] = dp[left_node][1] + dp[right_node][1] -1
                return
            if right_node == -1:
                if num[node] <= mid:
                    dp[node][0] = num[node]
                    dp[node][1] = dp[left_node][1] + 1
                else:
                    dp[node][0] = 0
                    dp[node][1] = INF

            else:
                if num[node] + min(dp[right_node][0],dp[left_node][0])<= mid:
                    dp[node][0] = num[node] + min(dp[right_node][0],dp[left_node][0])
                    dp[node][1] = dp[right_node][1] + dp[left_node][1]
                elif num[node] <= mid:
                    dp[node][0] = num[node]
                    dp[node][1] = dp[left_node][1] + dp[right_node][1] + 1
                else:
                    dp[node][0] = 0
                    dp[node][1] = INF

            



    while left+1<right:
        mid = (left+right)//2
        dp = [[0,1] for _ in range(N+1)]
        check(root)
        if dp[root][1]<=k:
            right = mid
        else:
            left = mid
    return right

풀기 힘들었던 문제이다.

 

 

이 문제는 트리 DP와 파라메트릭 서치를 활용한 문제로

 

k개의 그룹으로 나눌 수 있는 최소 값을 찾는것이다.

 

카카오 해설에서도 알수 있듯이 이분탐색으로 찾으면 되고,

 

이 이분 탐색의 중간값인 mid가 의미하는 것은 이 트리를 나눌 최대의 인원을 의미한다.

 

한 그룹의 인원이 mid가 넘어가지 않게 해주면 된다.

 

즉 mid의 값이 크면 k는 적게 나올것이고, mid의 값이 작으면 k보다 많게 나올 것이다.

 

 

우리는 left+1<right일때 끝나고

 

left일때 false이므로,

 

우리가 구하고자 하는 값은 right임을 알 수 있다.

 

문제에 들어가면,

 

가장 밑 노드에서부터 해주면 된다.

 

먼저 dp 테이블은 N*2의 크기로 해주었고,

dp[x][0] 은 x 노드에서의 최소 그룹인원의 수

dp[x][1] 은 x 노드에서의 그룹의 개수를 의미한다.

 

각 노드는 처음에 하나의 그룹이므로, dp[x][0]은 0으로 dp[x][1]은 1로 초기화를 해주었다.

 

총 4가지 경우가 있을것이다.

 

1. 자식노드 2개에 누적된 그룹인원과 현재노드의 인원과  더해도 mid보다 작거나 같다.

    그러면 dp[parent][0] = num[parent] + dp[left_node][0] + dp[right_node][1]이 될것이다.

    그리고 그룹의 수는 자식노드의 그룹의 수를 더해준 것에서 -1을 해준다.

    이러는 이유는 우리가 각 노드를 하나의 그룹으로 생각했기 때문에, 처음에 1로 전부 초기화 되어 있어서 그런다.

   자식노드와 현재그룹을 더해서 한개의 그룹이 된 것이므로 - 1을 해준다.

 

2. 자식노드 둘 중 1개 그룹인원과  현재 노드의 인원과 더하면 mid보다 작거나 같다.

 

   이럴때에는 그룹인원이 많은쪽의 간선을 끊는것과 마찬가지 이므로

   dp[parent][0] = num[parent] + min(dp[left_node][0], dp[right_node][1])

  을 해주고, 그룹은 두 자식노드의 그룹인원을 더해주면 된다.

 

3. 현재 그룹인원의 mid보다 작거나 같고, 자식노드들의 인원을 더하면 mid보다 크다.

 

   이럴 경우는 간선을 전부 끝는 것이다.

   그래서 dp[parent][0] = num[parent]를 해주고,

   그룹의 수는 1개를 더 더해주면 된다.

 

4. 현재 그룹의 인원의 mid보다 크다.

   이럴때에는 어떤 방도로도 mid보다 작거나 같은 그룹으로 나눌수 없으므로, 그룹인원의 수를 0으로 초기화 해주고, 그룹의 수를 INF로 해준다.

 

 

이 문제는 평소 코에서 나오지 않는 트리dp 문제인데다가, 이분탐색까지 활용해야되서 어려웠던 문제였다.

 

dp 테이블을 설계하는 것부터, 그룹의수가 k이면서 최소의 그룹인원 수를 어떻게 찾아하는지 생각하기 어려운 문제였다.

 

이 문제는 카카오 시험이 끝난뒤, 알고리즘 고수분들을 이야기를 통해, 어떻게 풀어야할지 감을 잡았고,

 

문제가 나오고 실제로 푸는데에도 시간이 오래걸린 문제였다.

 

이 문제가 다시 나오면 풀수 있다는 확신은 없지만, 한번 경험해봤으니 이러한 문제도 있다라는 것을 배운점에 만족했다.

 

import sys
sys.setrecursionlimit(20000)
def input():
    return sys.stdin.readline().rstrip()
def trace(cur_node,prev_check):
    visited[cur_node] = False
    flag = 0
    if not prev_check:
        if dp[cur_node][0] < dp[cur_node][1]:
            result.append(cur_node)
            flag = 1
    for next_node in graph[cur_node]:
        if visited[next_node]:
            trace(next_node,flag)

def dfs(node):
    visited[node] = False
    child_node = []
    for next_node in graph[node]:
        if visited[next_node]:
            child_node.append(next_node)
    if not len(child_node):
        dp[node][1] = arr[node]

        return
    else:
        dp[node][1] += arr[node]
        for child in child_node:
            dfs(child)
            dp[node][0] += max(dp[child][1],dp[child][0])
            dp[node][1] += dp[child][0]



N = int(input())
arr = [0] + list(map(int,input().split()))

# 1이 선택 
# 0이 선택 x
dp = [[0 for _ in range(2)] for _ in range(N+1)]

graph = [[] for _ in range(N+1)]
visited = [True for _ in range(N+1)]
for k in range(N-1):
    x,y = map(int,input().split())
    graph[x].append(y)
    graph[y].append(x)


dfs(1)

print(max(dp[1]))


result = []
visited = [True for _ in range(N+1)]
trace(1,0)
result.sort()
print(*result)

 

 

어려웠던 문제였다. 트리dp를 활용해서 최적의 경우를 찾는 것은 많았지만, 그 최적을 만족하는 것을 trace를 하는 문제는 처음이라 푸는데 오래걸렸다.

 

dp라는 테이블을 만들고 (N+1)*2의 배열을 만들어 두었다.

 

여기서 1은 해당 노드를 선택했을때의 최대값

 

0은 해당 노드를 선택하지 않았을때의 최대값이다.

 

만약 리프노드이면 자신을 선택하는 조건이외에는 없을것이다.

 

그러므로 dp[leef_node][1] = arr[leef_node]를 넣어준다.

 

리프노드를 제외한 나머지 노드들은 0번 인덱스와 1번인덱스의 최대값을 구해야한다.

 

현재 노드를 선택했을때 즉 1번인덱스에는

 

현재 노드를 cur_node라고 했을때 arr[cur_node]를 더해줘야할것이다.

 

그리고 현재 노드를 선택했으므로, 자식노드들은 선택을 하지 못한다.

 

그러므로 dp[cur_node][1] += dp[child_node][0] 와 같이

 

자식 노드들의 선택하지않았을때의 값들을 전부 더해줘야한다.

 

그리고 현재 노드를 선택하지 않았을때에는

 

자식 노드를 선택하던지, 선택하지 않는것은 마음대로 할 수 있다.

 

그러므로 둘중 더 큰 값을 더해주면 된다.

 

이렇게 dp에 값을 누적시켜서 구하면

 

우리는 1번노드로 시작했기 때문에

 

1번노드의 dp에서 최대값을 구할 수 있다.

 

그러면 우리는 이 dp에서 어떤 노드들이 선택됬는지를 찾아봐야한다.

 

 

우리는 이전 노드에서 선택했는지 안했는지에 따라, 현재 노드를 선택할수 있는지 아닌지가 결정이 된다.

 

그래서 우리는 prev_check를 통해 문제를 해결했다. 위와 동일하게 하기 위해서 0일때에는 부모노드를 선택하지 않았다.

 

1일때에는 부모노드를 선택했다이다.

 

만약 부모노드를 선택하지 않았을때에는, 우리는 지금 있는 현재노드를 선택할지 안할지를 결정지을수 있다.

 

그 결정방법은 dp에 저장된 값을 비교를 해서, 우리가 선택하지 않았을때보다 선택했을때의 값이 더 크면,

 

그때 현재노드를 선택을 하고, flag를 1로 바꿔준다. 현재노드를 선택을 했으므로, result에 추가 시켜준다.

 

그리고 재귀를 통해 모든 노드들에 대해 탐색을 해주면 된다.

 

이 문제는 최대값을 구하는 것 뿐만 아니라, 그 경우의 수가 만들어지는 것을 찾는데 더 어려움을 겪었다.

 

트리 dp와 관련된 문제들을 좀 더 연습해야겠다.

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 1045 도로  (0) 2021.06.29
[BOJ/백준] 2250 트리의 높이와 너비  (0) 2021.06.22
[BOJ/백준] 1414 불우이웃 돕기  (0) 2021.06.22
[BOJ/백준] 1799 비숍  (0) 2021.06.18
[BOJ/백준] 2233 사과 나무  (0) 2021.06.14
import sys

sys.setrecursionlimit(100001)


def dfs(node):
    visited[node] = True
    child_node = []
    for next_node in graph[node]:
        if not visited[next_node]:
            child_node.append(next_node)
    if not len(child_node):
        dp[node][0] = 1
        return
    else:
        for child in child_node:
            dfs(child)

        for child in child_node:
            if dp[child][0]:
                dp[node][1] = 1
                break
        else:
            dp[node][0] = 1

N = int(input())
graph = [[] for _ in range(N+1)]


for _ in range(N-1):
    x,y = map(int,input().split())
    graph[x].append(y)
    graph[y].append(x)
visited = [False]*(N+1)
dp = [[0]*2 for _ in range(N+1)]
dfs(1)

answer = min(list(map(sum,list(zip(*dp)))))
print(answer)

+ Recent posts