import sys

sys.setrecursionlimit(20000)

def solution(k, num, links):
    left = sum(num)//k
    right = sum(num)
    N = len(num)
    degrees = [0 for _ in range(N)]
    for ind in range(N):
        if links[ind][0] != -1:
            degrees[links[ind][0]] += 1
        if links[ind][1] != -1:
            degrees[links[ind][1]] += 1
    root = -1
    for ind in range(N):
        if not degrees[ind]:
            root = ind
            break
    INF = float('inf')
    def check(node):
        nonlocal links,INF,dp,mid
        if links[node][0] == -1 and links[node][1] == -1:
            if num[node]<=mid:
                dp[node][0] = num[node]
                dp[node][1] = 1
            else:
                dp[node][0] = 0
                dp[node][1] = INF
        else:
            for ind in range(2):
                if links[node][ind] != -1:
                    check(links[node][ind])
            left_node = links[node][0]
            right_node = links[node][1] 
            if dp[left_node][0] + dp[right_node][0] + num[node] <= mid:
                dp[node][0] = dp[left_node][0] + dp[right_node][0] + num[node]
                dp[node][1] = dp[left_node][1] + dp[right_node][1] -1
                return
            if right_node == -1:
                if num[node] <= mid:
                    dp[node][0] = num[node]
                    dp[node][1] = dp[left_node][1] + 1
                else:
                    dp[node][0] = 0
                    dp[node][1] = INF

            else:
                if num[node] + min(dp[right_node][0],dp[left_node][0])<= mid:
                    dp[node][0] = num[node] + min(dp[right_node][0],dp[left_node][0])
                    dp[node][1] = dp[right_node][1] + dp[left_node][1]
                elif num[node] <= mid:
                    dp[node][0] = num[node]
                    dp[node][1] = dp[left_node][1] + dp[right_node][1] + 1
                else:
                    dp[node][0] = 0
                    dp[node][1] = INF

            



    while left+1<right:
        mid = (left+right)//2
        dp = [[0,1] for _ in range(N+1)]
        check(root)
        if dp[root][1]<=k:
            right = mid
        else:
            left = mid
    return right

풀기 힘들었던 문제이다.

 

 

이 문제는 트리 DP와 파라메트릭 서치를 활용한 문제로

 

k개의 그룹으로 나눌 수 있는 최소 값을 찾는것이다.

 

카카오 해설에서도 알수 있듯이 이분탐색으로 찾으면 되고,

 

이 이분 탐색의 중간값인 mid가 의미하는 것은 이 트리를 나눌 최대의 인원을 의미한다.

 

한 그룹의 인원이 mid가 넘어가지 않게 해주면 된다.

 

즉 mid의 값이 크면 k는 적게 나올것이고, mid의 값이 작으면 k보다 많게 나올 것이다.

 

 

우리는 left+1<right일때 끝나고

 

left일때 false이므로,

 

우리가 구하고자 하는 값은 right임을 알 수 있다.

 

문제에 들어가면,

 

가장 밑 노드에서부터 해주면 된다.

 

먼저 dp 테이블은 N*2의 크기로 해주었고,

dp[x][0] 은 x 노드에서의 최소 그룹인원의 수

dp[x][1] 은 x 노드에서의 그룹의 개수를 의미한다.

 

각 노드는 처음에 하나의 그룹이므로, dp[x][0]은 0으로 dp[x][1]은 1로 초기화를 해주었다.

 

총 4가지 경우가 있을것이다.

 

1. 자식노드 2개에 누적된 그룹인원과 현재노드의 인원과  더해도 mid보다 작거나 같다.

    그러면 dp[parent][0] = num[parent] + dp[left_node][0] + dp[right_node][1]이 될것이다.

    그리고 그룹의 수는 자식노드의 그룹의 수를 더해준 것에서 -1을 해준다.

    이러는 이유는 우리가 각 노드를 하나의 그룹으로 생각했기 때문에, 처음에 1로 전부 초기화 되어 있어서 그런다.

   자식노드와 현재그룹을 더해서 한개의 그룹이 된 것이므로 - 1을 해준다.

 

2. 자식노드 둘 중 1개 그룹인원과  현재 노드의 인원과 더하면 mid보다 작거나 같다.

 

   이럴때에는 그룹인원이 많은쪽의 간선을 끊는것과 마찬가지 이므로

   dp[parent][0] = num[parent] + min(dp[left_node][0], dp[right_node][1])

  을 해주고, 그룹은 두 자식노드의 그룹인원을 더해주면 된다.

 

3. 현재 그룹인원의 mid보다 작거나 같고, 자식노드들의 인원을 더하면 mid보다 크다.

 

   이럴 경우는 간선을 전부 끝는 것이다.

   그래서 dp[parent][0] = num[parent]를 해주고,

   그룹의 수는 1개를 더 더해주면 된다.

 

4. 현재 그룹의 인원의 mid보다 크다.

   이럴때에는 어떤 방도로도 mid보다 작거나 같은 그룹으로 나눌수 없으므로, 그룹인원의 수를 0으로 초기화 해주고, 그룹의 수를 INF로 해준다.

 

 

이 문제는 평소 코에서 나오지 않는 트리dp 문제인데다가, 이분탐색까지 활용해야되서 어려웠던 문제였다.

 

dp 테이블을 설계하는 것부터, 그룹의수가 k이면서 최소의 그룹인원 수를 어떻게 찾아하는지 생각하기 어려운 문제였다.

 

이 문제는 카카오 시험이 끝난뒤, 알고리즘 고수분들을 이야기를 통해, 어떻게 풀어야할지 감을 잡았고,

 

문제가 나오고 실제로 푸는데에도 시간이 오래걸린 문제였다.

 

이 문제가 다시 나오면 풀수 있다는 확신은 없지만, 한번 경험해봤으니 이러한 문제도 있다라는 것을 배운점에 만족했다.

 

+ Recent posts