import sys
import heapq
def input():
    return sys.stdin.readline().rstrip()



def find_parents(X):
    if X == make_set[X]:
        return X
    make_set[X] = find_parents(make_set[X])
    return make_set[X]


def union(x,y):
    X = find_parents(x)
    Y = find_parents(y)

    if X !=Y:
        if ranks[X]< ranks[Y]:
            X,Y = Y,X
        make_set[Y] = X
        if ranks[X] == ranks[Y]:
            ranks[X] += 1
        return True
    return False
N,M = map(int,input().split())

edge_list = []

for x in range(N):
    temp = list(input())
    for y in range(x+1,N):
        if temp[y] == 'Y':
            heapq.heappush(edge_list,(x,y))
city_cnt = 0
if len(edge_list) >= M:
    result = [0 for _ in range(N)]
    make_set = [i for i in range(N)]
    ranks = [1 for _ in range(N)]
    remain_list = []
    while edge_list:
        node_a,node_b = heapq.heappop(edge_list)
        if union(node_a,node_b):
            city_cnt += 1
            result[node_b] += 1
            result[node_a] += 1
        else:
            heapq.heappush(remain_list,(node_a,node_b))
    if city_cnt != N-1:
        print(-1)
    else:
        remain_cnt = M - city_cnt

        while remain_cnt>0:
            node_a,node_b = heapq.heappop(remain_list)
            result[node_a] += 1
            result[node_b] += 1
            remain_cnt -= 1
        print(*result)
else:
    print(-1)

이 문제는 문제를 이해하는데 힘들었고, 구현을 하는데 어려운 점이 많았었다.

 

평소에 MST 관련 문제는 프림 알고리즘을 주로 이용해서 풀었지만, 이 문제를 푸는데에는 크루스칼 알고리즘을 이용했다.

 

문제를 푸는 방식은 다음과 같다. 모든 간선들을 저장을 해주는데, 그 조건은 x,y 에서 y는 무조건 x보다 큰 경우의 간선만 저장을 시켜준다.

 

그래서 모든 간선을 저장시켰는데 주어진 M보다 적을때에는 -1을 출력을 해준다.

 

그리고 크루스칼 알고리즘을 이용해서 풀면된다.

 

그러나, 여기서 다른 점은 우리가 이미 union이 된 간선들을 버리게 되는데, 이걸 따로 저장을 시켜놓는다.

 

그리고 우리는 모든 크루스칼 알고리즘을 돌린뒤에, 모든 도시가 연결이 안되어있으면 -1을 출력을 해주고,

 

그리고 우리는 무조건 도로 M개를 무조건 연결을 해야하므로, 우리는 크루스칼 알고리즘을 통해 N-1개를 연결을 해놓은 상태이다.

 

그러므로 M-(N-1)개의 도로를 추가적으로 연결을 시켜주면 된다.

 

그래서 우리는 저장시켜놓은 도로들을 우선순위가 높은 것부터 뽑아서 연결을 시켜주면 된다.

 

 

import sys
from collections import deque
def input():
    return sys.stdin.readline().rstrip()



def find_parents(X):
    if X == make_set[X]:
        return X
    make_set[X] = find_parents(make_set[X])
    return make_set[X]


def union(x,y):
    X = find_parents(x)
    Y = find_parents(y)

    if X !=Y:
        if ranks[X]< ranks[Y]:
            X,Y = Y,X
        make_set[Y] = X
        if ranks[X] == ranks[Y]:
            ranks[X] += 1
        return True
    return False
N,M = map(int,input().split())

edge_list = deque()
result = [0 for _ in range(N)]
make_set = [i for i in range(N)]
ranks = [1 for _ in range(N)]

city_cnt = 0
for x in range(N):
    temp = list(input())
    for y in range(x+1,N):
        if temp[y] == 'Y':
            if union(x,y):
                city_cnt += 1
                result[x] += 1
                result[y] += 1
            else:
                edge_list.append((x,y))


if city_cnt<N-1 or city_cnt+len(edge_list)<M:
    print(-1)
else:
    remain_cnt = M - city_cnt

    while remain_cnt>0:
        x,y = edge_list.popleft()
        result[x] += 1
        result[y] += 1
        remain_cnt -= 1
    print(*result)

이건 코드를 좀 더 개선시킨 버전이다.

 

달라진것은 heapq 대신 그냥 앞에서부터 차근차근 해주는 방식으로 바꿔준 것이다.

 

이 문제는 크루스칼을 이용하는 것뿐만 아니라, 문제를 잘 분석해야 풀 수 있었떤 문제였다.

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 2463 비용  (0) 2021.06.29
[BOJ/백준] 1398 동전  (0) 2021.06.29
[BOJ/백준] 2250 트리의 높이와 너비  (0) 2021.06.22
[BOJ/백준] 2213 트리의 독립집합  (0) 2021.06.22
[BOJ/백준] 1414 불우이웃 돕기  (0) 2021.06.22
import sys
from collections import defaultdict
def input():
    return sys.stdin.readline().rstrip()


def dfs(cur_idx,parent_idx):
    global cnt
    depth[cur_idx] = depth[parent_idx] + 1


    if tree[cur_idx][0] != -1:
        dfs(tree[cur_idx][0],cur_idx)
    cnt += 1
    width[depth[cur_idx]].append(cnt) 
    if tree[cur_idx][1] != -1:
        dfs(tree[cur_idx][1],cur_idx)

N = int(input())
root_check = [0 for _ in range(N+1)]
tree = [[-1 -1] for _ in range(N+1)]
depth = [0 for _ in range(N+1)]
width = defaultdict(list)
for _ in range(N):
    x,left_node,right_node = map(int,input().split())
    tree[x] = [left_node,right_node]
    if left_node != -1:
        root_check[left_node] += 1
    if right_node != -1:
        root_check[right_node] += 1
root_num = -1
for k in range(1,N+1):
    if root_check[k] == 0:
        root_num = k
        break


cnt = 0
dfs(root_num,0)

max_value = 0
max_depth = -1
for d in range(1,max(depth)+1):
    max_width,min_width = max(width[d]),min(width[d])
    if max_value < max_width - min_width + 1:
        max_value = max_width - min_width + 1
        max_depth = d

print(max_depth,max_value)

 

 

이 문제를 풀때 주의할점은 2가지이다.

 

루트노드가 꼭 1인것은 아니다.

 

들어오는 순서가 1,2,3,4,5,6,7,8,....N이 아닐수도 있다.

 

이 2가지를 주의하면 풀면 되는 문제이다.

 

 

먼저 트리를 위한 (N+1)*2의 배열을 만들어뒀다.

 

각 행의 0번인덱스는 왼쪽 자식, 1번인덱스는 오른쪽자식을 의미한다.

 

이 문제는 중위순회를 구현을 해서, 왼쪽 자식->루트->오른쪽자식 순서로 탐색을 해준다.

 

그리고 루트일때 현재의 LEVEL에 width를 저장을 해주엇다.

 

cnt를 0부터 시작했기때문에, append 하기 직전에 +=1 을 해준것이다.

 

만약 1부터 하신분들은 순서를 바꾸면될것이다.

 

그리고 전체 레벨 만큼 탐색을하면서 최대 너비를 구해주면 되는 문제이다.

 

좀 더 깔끔히 하신분들은 level별로 2차원배열을 만든뒤에 둘 중 하나는 최대값 하나는 최소값을 저장을 시켜서 

 

마지막에 바로 계산이 가능하게 한 분들도 계신다.

 

 

 

 

 

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 1398 동전  (0) 2021.06.29
[BOJ/백준] 1045 도로  (0) 2021.06.29
[BOJ/백준] 2213 트리의 독립집합  (0) 2021.06.22
[BOJ/백준] 1414 불우이웃 돕기  (0) 2021.06.22
[BOJ/백준] 1799 비숍  (0) 2021.06.18
import sys
sys.setrecursionlimit(20000)
def input():
    return sys.stdin.readline().rstrip()
def trace(cur_node,prev_check):
    visited[cur_node] = False
    flag = 0
    if not prev_check:
        if dp[cur_node][0] < dp[cur_node][1]:
            result.append(cur_node)
            flag = 1
    for next_node in graph[cur_node]:
        if visited[next_node]:
            trace(next_node,flag)

def dfs(node):
    visited[node] = False
    child_node = []
    for next_node in graph[node]:
        if visited[next_node]:
            child_node.append(next_node)
    if not len(child_node):
        dp[node][1] = arr[node]

        return
    else:
        dp[node][1] += arr[node]
        for child in child_node:
            dfs(child)
            dp[node][0] += max(dp[child][1],dp[child][0])
            dp[node][1] += dp[child][0]



N = int(input())
arr = [0] + list(map(int,input().split()))

# 1이 선택 
# 0이 선택 x
dp = [[0 for _ in range(2)] for _ in range(N+1)]

graph = [[] for _ in range(N+1)]
visited = [True for _ in range(N+1)]
for k in range(N-1):
    x,y = map(int,input().split())
    graph[x].append(y)
    graph[y].append(x)


dfs(1)

print(max(dp[1]))


result = []
visited = [True for _ in range(N+1)]
trace(1,0)
result.sort()
print(*result)

 

 

어려웠던 문제였다. 트리dp를 활용해서 최적의 경우를 찾는 것은 많았지만, 그 최적을 만족하는 것을 trace를 하는 문제는 처음이라 푸는데 오래걸렸다.

 

dp라는 테이블을 만들고 (N+1)*2의 배열을 만들어 두었다.

 

여기서 1은 해당 노드를 선택했을때의 최대값

 

0은 해당 노드를 선택하지 않았을때의 최대값이다.

 

만약 리프노드이면 자신을 선택하는 조건이외에는 없을것이다.

 

그러므로 dp[leef_node][1] = arr[leef_node]를 넣어준다.

 

리프노드를 제외한 나머지 노드들은 0번 인덱스와 1번인덱스의 최대값을 구해야한다.

 

현재 노드를 선택했을때 즉 1번인덱스에는

 

현재 노드를 cur_node라고 했을때 arr[cur_node]를 더해줘야할것이다.

 

그리고 현재 노드를 선택했으므로, 자식노드들은 선택을 하지 못한다.

 

그러므로 dp[cur_node][1] += dp[child_node][0] 와 같이

 

자식 노드들의 선택하지않았을때의 값들을 전부 더해줘야한다.

 

그리고 현재 노드를 선택하지 않았을때에는

 

자식 노드를 선택하던지, 선택하지 않는것은 마음대로 할 수 있다.

 

그러므로 둘중 더 큰 값을 더해주면 된다.

 

이렇게 dp에 값을 누적시켜서 구하면

 

우리는 1번노드로 시작했기 때문에

 

1번노드의 dp에서 최대값을 구할 수 있다.

 

그러면 우리는 이 dp에서 어떤 노드들이 선택됬는지를 찾아봐야한다.

 

 

우리는 이전 노드에서 선택했는지 안했는지에 따라, 현재 노드를 선택할수 있는지 아닌지가 결정이 된다.

 

그래서 우리는 prev_check를 통해 문제를 해결했다. 위와 동일하게 하기 위해서 0일때에는 부모노드를 선택하지 않았다.

 

1일때에는 부모노드를 선택했다이다.

 

만약 부모노드를 선택하지 않았을때에는, 우리는 지금 있는 현재노드를 선택할지 안할지를 결정지을수 있다.

 

그 결정방법은 dp에 저장된 값을 비교를 해서, 우리가 선택하지 않았을때보다 선택했을때의 값이 더 크면,

 

그때 현재노드를 선택을 하고, flag를 1로 바꿔준다. 현재노드를 선택을 했으므로, result에 추가 시켜준다.

 

그리고 재귀를 통해 모든 노드들에 대해 탐색을 해주면 된다.

 

이 문제는 최대값을 구하는 것 뿐만 아니라, 그 경우의 수가 만들어지는 것을 찾는데 더 어려움을 겪었다.

 

트리 dp와 관련된 문제들을 좀 더 연습해야겠다.

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 1045 도로  (0) 2021.06.29
[BOJ/백준] 2250 트리의 높이와 너비  (0) 2021.06.22
[BOJ/백준] 1414 불우이웃 돕기  (0) 2021.06.22
[BOJ/백준] 1799 비숍  (0) 2021.06.18
[BOJ/백준] 2233 사과 나무  (0) 2021.06.14
import sys
import heapq
def convert_Numner(x):
    global total_sum,INF
    if x.islower():
        num = ord(x) - ord('a') + 1
        total_sum += num
    elif x.isupper():
        num = ord(x) - ord('A')+ 27
        total_sum += num
    else:
        num = INF
        total_sum += 0
    return num


def input():
    return sys.stdin.readline().rstrip()


N = int(input())


total_sum = 0
INF = float('inf')
arr = [list(map(convert_Numner,list(input()))) for _ in range(N)]


cnt = 0
mst_dis = 0
distance_list = [INF for _ in range(N)]
visited = [False for _ in range(N)]
node_list = []

heapq.heappush(node_list,(0,0))
distance_list[0] = 0
while node_list:
    cur_dis,cur_node = heapq.heappop(node_list)
    if visited[cur_node]:continue
    cnt += 1
    mst_dis += cur_dis
    visited[cur_node] = True
    for next_node in range(N):
        if visited[next_node]:continue
        min_value = min(arr[cur_node][next_node],arr[next_node][cur_node])
        if distance_list[next_node] > min_value:
            distance_list[next_node] = min_value
            heapq.heappush(node_list,(min_value,next_node))

if cnt == N:
    print(total_sum-mst_dis)
else:
    print(-1)


 

 

처음에 프림알고리즘으로 풀었다. MST에서 익숙한 알고리즘이 프림알고리즘이라, 우선적으로 푼 방식이다.

 

여기서 여러번 틀렸었는데 그 이유는 다음과 같다. 우리가 연결을 하는 통로는 무방향이다.

 

즉 1번방에서 2번방으로 연결하는 랜선과 2번방에서 1번방으로 연결하는 랜선은 둘중 하나만 선택을 하면, 둘이 연결이 된것으로 본다.

 

이 문제를 풀때 무조건 0번방부터 시작을 했기때문에

 

3

zzz

abc

abc

 

와 같은 입력이 들어왔을때, 오히려 더 짧은 반대편방에서 오는 경우를 체크하지 못했다.

 

이러한 문제점을 해결하기 위해, 거리를 비교할때, 양방향의 최소값으로 비교를 하고, 그값을 넣어주는 방식으로 했다.

 

이런식이 아니라, 처음부터 graph라는 딕셔너리를 만들어서 최소값을 저장해주어도 된다.

 

이 점만 주의하면, 나머지는 일반적인 프림알고리즘을 이용한 MST 문제이다.

 

cnt개수가 N보다 부족할시에는 모든 노드들이 연결이 된것이 아니므로, -1을 출력해주면 된다.

 

 

 

import sys
def convert_Numner(x):
    global total_sum,INF
    if x.islower():
        num = ord(x) - ord('a') + 1
        total_sum += num
    elif x.isupper():
        num = ord(x) - ord('A')+ 27
        total_sum += num
    else:
        num = INF
        total_sum += 0
    return num


def find_parents(X):
    if X == make_set[X]:
        return X
    make_set[X] = find_parents(make_set[X])
    return make_set[X]


def union(x,y):
    X = find_parents(x)
    Y = find_parents(y)

    if X !=Y:
        if ranks[X]< ranks[Y]:
            X,Y = Y,X
        make_set[Y] = X
        if ranks[X] == ranks[Y]:
            ranks[X] += 1
        return True
    return False

def input():
    return sys.stdin.readline().rstrip()

N = int(input())


total_sum = 0
INF = float('inf')


edge_list = []
for x in range(N):
    input_string = input()
    temp = []
    for y in range(N):
        num = convert_Numner(input_string[y])
        if x == y:continue
        if num == INF:continue
        edge_list.append((num,x,y))
edge_list.sort(reverse=True)
cnt = 1
result = 0
make_set = [i for i in range(N)]
ranks = [1 for _ in range(N)]
while cnt <N and edge_list:
    pay,node_a,node_b = edge_list.pop()
    if union(node_a,node_b):
        cnt += 1
        result += pay


if cnt == N:
    print(total_sum-result)
else:
    print(-1)

 

크루스칼을 이용한 풀이이다. 푸는 방법은 위와 동일하다 대신 edge_list에 x와 y가 같은경우와 연결이 되지않는 경우는 제외시켰다.

 

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 2250 트리의 높이와 너비  (0) 2021.06.22
[BOJ/백준] 2213 트리의 독립집합  (0) 2021.06.22
[BOJ/백준] 1799 비숍  (0) 2021.06.18
[BOJ/백준] 2233 사과 나무  (0) 2021.06.14
[BOJ/백준] 1050 물약  (0) 2021.06.14
import sys

def input():
    return sys.stdin.readline().rstrip()

def check_pick(pick_list,point):
    x,y = point
    for nx,ny in pick_list:
        if abs(nx-x) == abs(ny-y):
            return False
    return True

def dfs(problem_list,start,pick_list,color):
    global result
    if len(problem_list) == start:
        result[color] = max(result[color],len(pick_list))
        return
    else:

        for idx in range(start,len(problem_list)):
            x,y = problem_list[idx]
            if check_pick(pick_list,(x,y)):
                dfs(problem_list,idx+1,pick_list+[(x,y)],color)


N = int(input())


arr = [list(map(int,input().split())) for _ in range(N)]

result = 0
black_set = []
white_set = []
for x in range(N):
    for y in range(N):
        if arr[x][y]:
            if (x+y)%2:
                white_set.append((x,y))
            else:
                black_set.append((x,y))

result = [0,0]
dfs(black_set,0,[],0)
dfs(white_set,0,[],1)

print(sum(result))

문제를 푼 아이디어는 다음과 같다. 체스판은 흰색판과 검은색판으로 나뉘어져있다. 그리고 비숍이 잡아 먹을수 있는 위치는 같은 색깔에 있는 위치의 비숍일 뿐이다. 그러므로, 흰색과 검은색으로, 좌표를 나눠서, 각각 놓을 수 있는 최대 위치를 찾아주면 된다.

 

# chw0501님 코드 복기

import sys

def input():
    return sys.stdin.readline()

def dfs(left_slash_num):
    if visited[left_slash_num]:
        return 0 
    visited[left_slash_num] = True

    for right_slash_num in slash_dict[left_slash_num]:
        if right_slash[right_slash_num] == -1 or dfs(right_slash[right_slash_num]):
            left_slash[left_slash_num] = right_slash_num
            right_slash[right_slash_num] = left_slash_num
            return 1
    return 0

N = int(input())
arr = [list(map(int,input().split())) for _ in range(N)]
right_slash = [-1 for _ in range(2*N+1)]
left_slash = [-1 for _ in range(2*N+1)]
slash_dict = [[] for _ in range(2*N+1)]
for x in range(N):
    for y in range(N):
        if arr[x][y]:
            slash_dict[x+y].append(x-y+N)


result = 0
for left_slash_num in range(2*N):
    visited = [False for _ in range(2*N)]
    if dfs(left_slash_num):
        result += 1
print(result)

 

더 빨리 풀 수 있는 코드는 대각선으로 나눠서 하는 방법인데 거기서 깔끔한 코드였던 chw0501님의 코드를 복기한 코드이다.

 

이 문제를 푸는방법은 다음과 같다. 한 위치에 비숍이 위치하면, 이 비숍의 X자 대각선으로는 다른 비숍들을 둘수 없다.

 

그렇다면, 흰색, 검은색으로만 나누지 말고, 대각선 위치를 기준으로 나눠 보는건 어떨까 라는 아이디어에서 출발된 것이다.

 

대각선을 (2*N-1)개씩 두 종류로 나뉠수 있다.

 

 

먼저 right_slash라고 부를, 오른쪽 아래로 향하는 대각선 모양은 다음과 같이 총 2*N-1개가 있다.

 

그러면 같은 대각선에 있는지 확인하는 방법은 행을 X 열을 Y로 했을때 (X-Y+N)으로 표현할 수 있다.

 

 

그리고 이렇게 좌측 아래로 향하는 슬래쉬를 left_slash라고 하겠다 이것도 위와 동일하게 2*N-1개가 생성되면

 

같은 집단인지 판별하는 방법은 (X+Y)이다.

 

그러면 우리는 해당 칸이 놓을수 있는 자리이면, left_slash를 구하고, 그 left_slash내에 해당 좌표의 right_slash 정보를 같이 집어넣어준다.

 

즉 slash_dict의 key 값은 left_slash의 좌표값이 되고, right_slash가 value가 되도록 넣어주면 된다.

 

이렇게 사전작업을 해놓은 상태에서 우리는 2개의 slash_list를 만들어놓는다.

 

각각 left_slash 와 right_slash로

 

해당의 index는 slash의 몇번째 위치에 있는지를 나타내며, 그 안의 value 값은 서로 반대편의 슬래쉬의 위치를 표시를 해주는 용도이다.

 

즉 left_slash의 3번째 위치에 5라는 값이 있으면,

 

3번째 left_slash에 있는 어떠한 점 중 하나를 골랐고, 그 점은 right_slash를 5를가지고 있다.를 의미한다.

 

즉 특정한 점을 만들 수 있는 것이다.

 

그러면 우리는 left_slah의 좌측아래로 향하는 슬래쉬를 기준으로 dfs를 진행을 해주면 된다.

 

def dfs(left_slash_num):
    if visited[left_slash_num]:
        return 0 
    visited[left_slash_num] = True

    for right_slash_num in slash_dict[left_slash_num]:
        if right_slash[right_slash_num] == -1 or dfs(right_slash[right_slash_num]):
            left_slash[left_slash_num] = right_slash_num
            right_slash[right_slash_num] = left_slash_num
            return 1
    return 0

 

이 dfs가 중요한데 이 dfs에서는 기본적으로 다음과 같다.

 

우리가 입력받은 left_slash_num에 저장된 right_slash_num들을 하나씩 가져온다.

 

right_slash라는 list에서 초기값으로 있으면 그 자리에 비숍을 놓을 수 있는 것이니,

 

left_slash에는 right_slash_num을 저장해주고, right_slash에는 left_slash_num을 저장을 해준다.

 

그러나 만약 해당 right_slash에 특정한 left_slash_num이 저장이 되어있다면,

 

우리는 위치를 조정해서 다시 놓을 수 있는지 확인을 해주는 것이다.

 

우리가 A_right라는 right_slash를 찾아볼려고 했더니 right_slash[A_right]의 값이 -1이 아니였고, B_left라는 값이 있었으면,

 

우리는 DFS를 통해 B_left를 다시 들어가서 B_left안에 있는 다른 right_slash들을 검사를 해서 놓을 수 있는지 확인을 해주는 것이다.

 

B_left 안에 [A_right,B_right,C_right] 등이 있다고하면,

 

이 중에 아직 right_slash가 -1인 값이 있으면, 우리는 그 위치를 조정해서 다시 비숍을 놓을 수 있을것이다.

 

그러나 이러한 경우가 하나도 없다면 비숍을 놓을 수 없기 때문에 0을 반환을 해주고,

 

한번 방문했던 곳을 다시 방문을 했다는 것은 사이클이 발생하고, 더 비숍을 놓을 곳이 없는것을 의미하기 때문에 0을 반환을 해주면된다.

 

위와 같은 dfs를 실행을 하면서 1을 반환이 되면, result의 값을 1씩 늘려주고,

 

최종적으로 result를 출력을 해주면 된다.

 

말로서 설명하기에는 어렵지만 디버깅도구를 통해 따라가보면 금방 이해할 수 있었다.

 

두번째 풀이는 이해하기 어려웠지만, 다른 사람의 코드를 보고 이런 방식으로도 할 수 있는걸 알았다.

 

좀 더 효율적인 방법이 있다는 걸 알고, 다음번에는 이런 코드를 짤 수 있도록 노력해야겠다.

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 2213 트리의 독립집합  (0) 2021.06.22
[BOJ/백준] 1414 불우이웃 돕기  (0) 2021.06.22
[BOJ/백준] 2233 사과 나무  (0) 2021.06.14
[BOJ/백준] 1050 물약  (0) 2021.06.14
[BOJ/백준] 15653 구슬 탈출 4  (0) 2021.06.14
import sys
sys.setrecursionlimit(10000)
def input():
    return sys.stdin.readline().rstrip()


def LCS(X,Y):
    if depth[X] != depth[Y]:
        if depth[X]>depth[Y]:
            X,Y = Y,X
        for i in range(MAX_NODE-1,-1,-1):
            if (depth[Y]-depth[X] >= (1<<i)):
                Y = parents[Y][i]
    if X == Y:
        return X
    if (Y != X):
        for i in range(MAX_NODE-1,-1,-1):
            if parents[X][i] != parents[Y][i]:
                X = parents[X][i]
                Y = parents[Y][i]
    return parents[X][0]
def tree_make(idx,cur_node,node_cnt,parent_idx):
    if idx == len(binary_list):
        return
    if binary_list[idx] == 0:
        node_cnt += 1
        cur_node = node_cnt
        binary_search[idx+1] = cur_node
        tree[cur_node].visited.append(idx+1)
        tree[cur_node].parent = parent_idx
        depth[cur_node] = depth[parent_idx] + 1
        parents[cur_node][0] = parent_idx
        if cur_node != 1:
            parent_node = tree[cur_node].parent
            if tree[parent_node].left == None:
                tree[parent_node].left = cur_node
            else:
                tree[parent_node].right = cur_node
        tree_make(idx+1,cur_node,node_cnt,cur_node)
    else:
        binary_search[idx+1] = cur_node
        tree[cur_node].visited.append(idx+1)
        cur_node = tree[cur_node].parent
        tree_make(idx+1,cur_node,node_cnt,cur_node)



class Node():
    def __init__(self,data):
        self.data = data
        self.left = None
        self.right = None
        self.parent = None
        self.visited = []

N = int(input())
binary_list = list(map(int,list(input())))
i,j = map(int,input().split())
tree = [Node(i) for i in range(N+1)]
MAX_NODE = 15
binary_search = [0 for _ in range(len(binary_list)+1)]
depth = [0 for _ in range(N+1)]
parents = [[0 for _ in range(MAX_NODE)] for _ in range(N+1)]

tree_make(0,0,0,0)
for par_idx in range(1,MAX_NODE):
    for k in range(1,N+1):
        parents[k][par_idx] = parents[parents[k][par_idx-1]][par_idx-1]

X,Y = binary_search[i],binary_search[j]

result_node = LCS(X,Y)
print(*tree[result_node].visited)

 

 

LCA를 이용해서 풀어본 처음 문제였다.

 

LCA를 공부하면서 푼 풀이이다. 

 

LCA에 대한 자세한 설명은

 

https://jason9319.tistory.com/90    https://www.crocus.co.kr/660    https://m.blog.naver.com/kks227/220820773477  https://velog.io/@lre12/LCA-%EC%95%8C%EA%B3%A0%EB%A6%AC%EC%A6%98Lowest-Common-Ancestor

 

LCA(Lowest Common Ancestor)

1개 이상의 노드로 구성 된 사이클이 없는 그래프를 트리라고 합니다. 우리는 트리에서 정의되는 LCA(Lowest Common Ancestor)에 대해서 얘기해 보려 합니다. LCA를 직역하면 최소 공통 조상(?) 정도의 뜻

jason9319.tistory.com

위의 사이트들을 참조하길 바란다.

 

 

LCA를 공부하면서, 이해하다가 막혔다가 푼 부분을 적으면 다음과 같다.

 

 

1. 부모를 저장하는 dp 테이블에는 해당 노드의 모든 부모를 저장하는 것이 아닌 2^k번째 부모들만 압축해서 저장해놓는것이다.

 

2. 위의 파생이지만, 그렇기 때문에, 두 높이의 차이가 1<<i 즉 2^i 차이가 나면, i만큼 올려주면 되는것이다.

 

3. 그리고 두 높이가 같고, 서로 다르다면, k번째 위치부터 다른 값까지 한꺼번에 올려버린다.

  즉, dp[node1][k] , dp[node2][k] 는 다르지만, dp[node1][k+1], dp[node][k+1]이 같으면

   각 노드들의 2^k + 1 ~ 2^(k+1) 사이에 최소 공통 조상이 있을 것이다. 이 부분을 주의하면 된다.

 

4. 또한, 첫번째에서 말한것처럼 2^k번째 만 저장을 시켜놓은것이기때문에

 

    2^(k+1) = 2^k + 2^k 이다.

    2^k = 2^(k-1) + 2^(k-1) 이므로,

dp[node][k] = dp[ dp[node][k-1] ] [k-1] 으로 하면 LCA의 해당 노드들의 모든 공통조상부모에 대해서 저장시킬수 있다.

 

위의 것들이 윗 링크들을 공부하면서 제가 깨달은 부분이다.

 

인제는 문제로 돌아가서 

 

여기서 문제에 주어진 비트를 트리구조로만 바꾸면 우리는 위에서 배운 LCA를 적용시키면 된다.

 

저는 여기서 이 문제를 해결하기 위해 Node라는 클래스를 만들어놨습니다.

 

해당 Node에는 총 5가지의 파라미터가 있습니다.

 

data는 이 노드의 번호를 나타내고,

left는 왼쪽 자식번호

right는 오른쪽 자식번호

parent는 부모 번호

visited는 우리가 주어진 비트에서 몇번재 비트에서 이 노드를 방문했는지 표시르 위한 것입니다.

 

저는 tree_make라는 재귀함수를 만들어 이 문제를 해결했습니다.

 

입력 parameter는 4가지가 있고, idx는 현재 입력으로 주어진 비트의 몇번째 idx인지를 나타냅니다.

 

cur_node는 현재 노드의 번호입니다.

 

node_cnt는 지금까지 생성된 노드의 개수입니다.

 

parent_idx는 부모 노드의 번호입니다.

 

비트에서 우리가 0을 만나면 최초로 생성되는 노드로 진입하게 되는겁니다.

 

그래서 node_cnt를 1을 늘려주고, cur_node를 node_cnt로 바꿔줍니다.

 

그러면 우리는 현재 부모노드의 정보를 알고 있으므로,

 

부모노드의 left와 right 중 빈 순서대로 넣어주면 됩니다.

 

그리고 1을 만나게 되면 우리는 현재 방문하고 있던 노드에서 다시 되돌아가서 부모노드로 가야합니다.

 

그렇기 때문에 cur_node를 부모노드로 치환시키고, 재귀를 반복하면 됩니다.

 

즉 0을 만나면 새로운 노드가 생성이되고,

 

1을 만나면 부모노드로 되돌아간다는 점만 기억을 하면, 문제에서 주어진 비트를 트리로 만들 수 있습니다.

 

코드가 깔끔하지 못해, 다른 분들의 코드를 보는 것을 더 추천합니다.

 

위와 같은 작업을 해서 비트를 트리구조로 바꾼뒤에 LCA을 적용시키면 이 문제를 해결 할 수 있습니다.

 

LCA와 관련된 문제를 처음 풀어보았기에, 헷갈리던 점도 많았고, 트리에 대해 숙달되지 못했기에,

 

비트를 트리로 구현하는데에 오래 걸린 문제였습니다.

 

매번 트리와 관련된 문제가 나오면 문제를 해결하는데 시간이 오래걸리는 것 같은데,

 

이 부분을 숙달하는데 더 노력을 해야할 것 같습니다.

 

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 1414 불우이웃 돕기  (0) 2021.06.22
[BOJ/백준] 1799 비숍  (0) 2021.06.18
[BOJ/백준] 1050 물약  (0) 2021.06.14
[BOJ/백준] 15653 구슬 탈출 4  (0) 2021.06.14
[BOJ/백준] 21944 문제 추천 시스템 Version2  (0) 2021.06.11
import sys
from collections import defaultdict,deque
def input():
    return sys.stdin.readline().rstrip()
N,M = map(int,input().split())
material_dict = {'-1':True}
# dict를 고정시키는게 문제
material_pay = [float('inf') for _ in range(10000)]
queue = deque()
for k in range(N):
    name,pay = input().split()
    material_dict[name] = k+1
    material_pay[k+1] = int(pay)
    queue.append(material_dict[name])



graph = defaultdict(list)


total_recipe_list = [input() for _ in range(M)]
recipe_cnt = defaultdict(int)
recipe_dict = defaultdict(list)
recipe_order = defaultdict(list)
total_recipe_list.sort()

for total_recipe in total_recipe_list:
    complete_name,arg = total_recipe.split('=')
    recipe = arg.split('+')
    if material_dict.get(complete_name):
        complete_idx = material_dict[complete_name]
    else:
        complete_idx = len(material_dict.keys())
        material_dict[complete_name] = complete_idx

    recipe_idx = (complete_idx,recipe_cnt[complete_idx])
    recipe_cnt[complete_idx] += 1
    temp_set = set()
    temp = defaultdict(int)
    for res in recipe:
        cnt,name = int(res[0]),res[1:]
        if material_dict.get(name):
            name_idx = material_dict[name]
        else:
            name_idx = len(material_dict.keys())
            material_dict[name] = name_idx
        if recipe_idx not in graph[name_idx]:
            graph[name_idx].append(recipe_idx)
        temp_set.add(name_idx)
        temp[name_idx] += cnt
    recipe_order[complete_idx].append(temp_set)
    recipe_dict[complete_idx].append(temp)

flag = True
        
total_len = len(material_dict) - 1
cnt = 0
while queue:

    name_idx = queue.popleft()
    if graph.get(name_idx):
        for next_idx,recipe_idx in graph[name_idx]:
            # 다시 들어왔을때 문제
            if name_idx in recipe_order[next_idx][recipe_idx]:
                recipe_order[next_idx][recipe_idx].remove(name_idx)


            if not len(recipe_order[next_idx][recipe_idx]):
                temp = 0

                for mat,val in recipe_dict[next_idx][recipe_idx].items():
                    temp += material_pay[mat]*val
                if material_pay[next_idx] > temp:

                    queue.append(next_idx)
                    material_pay[next_idx] = temp
                
if material_dict.get('LOVE'):
    result = material_pay[material_dict['LOVE']]
    if result == float('inf'):
        print(-1)
    elif result > 1000000000:
        print(1000000001)
    else:
        print(result)
else:
    print(-1)

이 문제는 여러번 틀렸던 문제였다.

 

풀면서 주의해야할 점은 다음과 같다.

 

처음에 주어지는 재료의 개수가 N의 최대가 50이라고 했지

 

 

모든 레시피에서 주어지는 재료의 개수가 50인게 아니다.

 

그래서 이걸 저장을 할때에는 넉넉하게해줘야한다.

 

 

두번째 여기서는 사이클같은 구조가 생길수 있다.

 

즉 위상정렬에서 한번 검사했던 노드이지만,

 

들어간 재료 중 하나가 최저값이 갱신이 되서, 현재 만들어지는 재료가 더 줄어둘수도 있다.

 

세번째 들어오는 입력에서 같은 재료가 여러번에 나뉘어서 들어 올수 있다.

 

네번째 같은 재료를 만드는 레시피가 여러종류가 있을 수 있다.

 

 

그래서 이 문제를 풀때 기본 컨셉은 다음과 같다.

 

각 완성재료를 만드는 전체 레시피들은 recipe_dict에 넣어줬다.

 

같은 완성재료를 만드는 방법이 여러개라면, recipe_cnt로 각 완성재료를 만드는 n번째 idx 이런식으로 구분을 해주었다.

 

그리고 위상정렬을 위해 order를 저장해주는 방식은

 

recipe_order[완성재료][레시피idx] 에 들어간 재료를 set으로 저장을 해주었다.

 

그리고 가장 중요한 그래프는 다음과 같이 저장을 해주었다.

 

graph라는 dictionary에 한 완성재료를 만드는 부품에 각각 (완성재료,완성재료를 만드는 레시피의 번호)를 저장시켜주었다.

 

그리고 material_dict은 재료명 대신 숫자로 관리하기 위해 만들어놨다.

 

material_pay는 각 재료를 만드는데 최소비용을 저장시켜놓은 리스트이다.

 

그러면 위상정렬을 어떻게 시작하면 되면

 

최초에 주어진 재료 N개를 queue에 넣어준다.

 

그리고 그 queue에서 하나씩 꺼내서

 

graph[name_idx]를 통해

 

이 재료로 만들 수 있는 레시피들에서 이 재료를 삭제시켜주었다.

 

이때 나중에 또 방문할 수 있으니, 있는지 확인을 하고 제거를 해준다.

 

 

그리고 이 recipe_order에 있는 set이 전부 사라져서 길이가 0 이면,

 

그때 queue에 들어갈수 있는지 확인해준다.

 

우리가 저장해놓은 recipe_dict을 이용해서 현재 재료들의 최저가로 만들었을때 나오는 비용을 계산한다.

 

비용을 계산한 후,  material_pay에 있는 값과 비교를 해서

 

최저가가 갱신이 되면 queue에 넣어준다.

 

그리고 이 반복문은 queue가 빌때까지 무한 반복해준다.

 

 

그러면 우리는 결과를 출력할때 두가지로 먼저 나눈다.

 

레시피와 재료에 아예 LOVE가 없을때, 그때는 -1을 출력해준다.

 

레시피와 재료에 있지만, 초기값인 float('inf')일때에는 -1을 출력해준다.

 

왜냐하면 어디선가 재료가 부족해서 LOVE에 아예 접근을 못한상태이기 때문이다.

 

그리고 1000만이 넘었을때에는 1000만1 그 외에는 재료값을 출력하게 해준다.

 

 

푼 사람들 코드 중에 제가 푼 코드보다 가독성이 좋고, 좋은 코드들이 있다.

 

가독성을 늘리는 방법은 중간에 제가했던 사전작업과 위상정렬을 동시에 해버리면된다.

 

어떤 상황에서 큐에 들어가는지와 어떤상황에서 종료를 해야하는지 파악하면 풀 수 있는 문제이다.

 

즉 들어온 입력들 전체를 while문을 돌리면서

 

단 한번도 최저값이 갱신된적이 없거나, 새로운 재료가 생기지 않았다면, 종료를 해주면 된다.

 

 

 

 

 

 

 

 

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 1799 비숍  (0) 2021.06.18
[BOJ/백준] 2233 사과 나무  (0) 2021.06.14
[BOJ/백준] 15653 구슬 탈출 4  (0) 2021.06.14
[BOJ/백준] 21944 문제 추천 시스템 Version2  (0) 2021.06.11
[BOJ/백준] 21943 연산 최대로  (0) 2021.06.11
import sys
from collections import deque
input = sys.stdin.readline

def bfs(red,blue):
    stack = deque()

    stack.append((*red,*blue,0))
    dx = [-1,1,0,0]
    dy = [0,0,-1,1]
    while stack:
        rx,ry,bx,by,dis = stack.popleft()

        for i in range(4):
            nrx,nry = rx,ry
            r_p = 0
            while 0<=nrx<N and 0<=nry<M and arr[nrx][nry] != '#' and arr[nrx][nry] != 'O':
                nrx += dx[i]
                nry += dy[i]
                r_p += 1
            nbx,nby = bx,by
            b_p = 0
            while 0<=nbx<N and 0<=nby<M and arr[nbx][nby] != '#' and arr[nbx][nby] != 'O':
                nbx += dx[i]
                nby += dy[i]
                b_p += 1

            if (nbx,nby) == (nrx,nry):
                if arr[nbx][nby] == 'O':
                    continue
                if r_p > b_p:
                    nrx -= dx[i]
                    nry -= dy[i]
                else:
                    nbx -= dx[i]
                    nby -= dy[i]

            elif arr[nbx][nby] == 'O':
                continue
            elif arr[nrx][nry] == 'O':
                return dis+1
            nrx -= dx[i]
            nry -= dy[i]
            nbx -= dx[i]
            nby -= dy[i]
            if not visited[nrx][nry][nbx][nby]:continue
            visited[nrx][nry][nbx][nby] = False
            stack.append((nrx,nry,nbx,nby,dis+1))
    return -1

N,M = map(int,input().split())


arr = []
blue = []
red = []
for x in range(N):
    temp = list(input())
    for y in range(M):
        if temp[y] == 'B':
            blue = (x,y)
        elif temp[y] == 'R':
            red = (x,y)
    arr.append(temp)


visited = [[[[True for _ in range(M)] for _ in range(N)] for _ in range(M)] for _ in range(N)]


result = bfs(red,blue)
print(result)

 

 

구슬 탈출의 마지막 시리즈다.

 

구슬탈출3에서 썼던 코드에서 경로추적이랑 10이상일때 종료인것만 제외시켜줬다.

 

구슬탈출1~4까지는 다 똑같은 코드이므로,

 

하나만 잘 풀어놓으면

 

코드를 1~3군데만 고쳐도 전부 통과할수 있다.

+ Recent posts