import heapq

def solution(n, start, end, roads, traps):
    answer = 0
    INF = float('inf')
    dp = [[INF for _ in range(n+1)] for _ in range(1<<len(traps))]
    traps_index ={value:ind  for ind,value in enumerate(traps) }
    node_list = []
    graph =[[] for _ in range(n+1)]

    for road in roads:
        x,y,pay = road
        graph[x].append([y,pay,0])
        graph[y].append([x,pay,1])

    heapq.heappush(node_list,(0,start,0))
    dp[0][start] = 0
    while node_list:
        cur_time,cur_node,state = heapq.heappop(node_list)
        if end == cur_node:
            answer = cur_time
            break
        if dp[state][cur_node] < cur_time:
            continue
        for next_node,pay,flag in graph[cur_node]:
            next_state = state
            if cur_node in traps:
                if next_node in traps:
                    cur_flag = ((1&(state>>traps_index[cur_node])) + 
                    (1&(state>>traps_index[next_node])))%2
                    next_state = state^(1<<traps_index[next_node])
                else:
                    cur_flag = (1&(state>>traps_index[cur_node]))
            else:
                if next_node in traps:
                    cur_flag =  (1&(state>>traps_index[next_node]))
                    next_state = state^(1<<traps_index[next_node])
                else:
                    cur_flag = 0
            
            if cur_flag == flag:
                if dp[next_state][next_node] > cur_time + pay:
                    dp[next_state][next_node] = cur_time + pay
                    if next_node in traps:
                        heapq.heappush(node_list,(dp[next_state][next_node], next_node,next_state ))
                    else:
                        heapq.heappush(node_list,(dp[next_state][next_node],next_node,next_state))

    return answer

이 문제를 핵심은 트랩을 밟았을 때의 상태를 어떻게 저장할 것인가.

 

저장한 그 상태를 통해 현재 길의 상태를 어떻게 알 수 있을가인가.

 

이 문제는 총 4가지 경우가 있다고 볼 수 있다.

 

트랩이 아닌  곳 -> 트랩이 아닌곳

 

트랩이 아닌 곳 -> 트랩인 곳

 

트랩인 곳 - > 트랩이 아닌곳

 

트랩인 곳 -> 트랩인 곳

 

이렇게 총 4가지 경우가 있을거고, 각각의 상태를 구분해서, 하면 된다.

 

우리는 먼저 주어진 경로들을 2가지 형태로 저장해놔야한다.

 

원래 주어진 정방향과, 이게 뒤집혔을때 가는 역방향이다. 이걸 나는 (다음노드, 시간, 상태) 로 넣어놨다.

 

0이면 정방향

 

1이면 역방향을 의미하는 식으로 넣어놨다.

 

트랩은 최대10개 밖에 없고, 트랩의 상태는 비트마스킹을 통해 나타낼수있다.

 

그렇게 하기 위해 각각의 trap들의 index로 각자의 위치를 매핑시켜줬다.

 

그러면 인제부터 각위치의 비트가 1이면 해당 위치의 트랩이 활성화 된 상황이고, 비활성화일때는 0이다.

 

그리고 우리는 각 상태에서 다익스트라를 돌리는 것이기 때문에 최대 (2**10) * N개의 distance 테이블을 만들어놓고 탐색을 하면 된다.

 

원래의 다익스트라와 비슷하게 하지만, 우리는 state를 통해 현재 길을 이용할 수 있는지 없는지 판별을 해줘야한다.

 

한 그래프에 정방향, 역방향을 전부 넣어놨기 때문에, state를 통해서 해당 길을 이용할 수 있는지 찾아야한다.

 

먼저 가장 쉬운

 

현재 노드가 트랩이 아닐때이다.

 

만약 다음 노드가 트랩이 아니라면, 이 길은 항상 정방향이다. 그러므로 저장해놓길 0으로 저장해놓은 길들만 이용이 가능하다.

 

만약 다음 노드가 트랩이라면, state를 통해 해당 트랩이 활성화 되어있는지 판별을 하면 된다.

 

판별하는 방법은 현재 state를 해당 트랩의 index만큼 >> 오른쪽으로 비트 이동을 시킨다. 그리고 1과 & 연산을 하면 된다.

 

이게 1이면 활성화가 된 상태이고, 아니면 비활성화 상태이다.

 

 

다음은 현재노드가 트랩일때이다.

 

현재노드가 트랩이고, 다음노드가 트랩이 아닐때에는 위와 동일하므로 넘어가겠다.

 

 

가장 많이 고민했던, 현재노드와 다음노드가 전부 트랩일때이다.

 

이때는 총 4가지 경우가 생긴다.

 

1. 현재노드 비활성화 다음노드 비활성화

2 .현재노드 활성화 다음노드 비활성화

3. 현재노드 비활성화 다음노드 활성화

4. 현재노드 활성화 다음노드 활성화

 

총 4가지 경우가 생긴다.

 

1번의 경우는 둘다 활성화가 안되어있기 때문에 정방향인 길들만 가면된다.

 

2,3번은 동일하고, 한쪽만 활성화 되어있으면, 그 길은 한번 뒤집힌거기 때문에 역방향인 길들만 가면 된다.

 

4번은 현재노드 다음노드 전부 활성화 되어있을때이다.

이때는 이 길이 2번 뒤집힌거기 때문에, 정방향인 길들만 가면된다.

 

그래서 나는 

 

(    ( 1 & (state>>traps_index[cur_node]) ) + ( 1&( state>>traps_index[next_node]) ) )%2

2개의 상태를 더해서 2로 나눈 나머지로 해서 정방향, 역방향을 구분을 해줬다. 이 외에도 xor 연산을 해도 된다.

 

 

우리는 이렇게 해서 해당 길이 갈 수 있는 길인지 아닌지 판별을 했다.

 

나머지는 다익스트라와 동일하게 하면되는데, 

 

 

 

그리고 현재의 상태와 비교하는게 아니라, 다음의 상태와 비교를 해서 판별을 해주었다.

 

만약 다음 노드가 트랩이라 한다면, 그 트랩은 활성화가 된 상태인거고, 그때 최소값으로 들어가야한다고 생각했다

 

그래서 각 현재 시간과 다음노드까지 걸리는 시간은 현재 state가 아닌, 만약 다음노드가 트랩이라면, 트랩의 상태에 따라, 변화된 state로 비교를 해서 구현을 했다.

 

 

import heapq
def solution(n, k, cmds):
    answer = ''
    def inverse(num):
        return -num
    max_heap = list(map(inverse,range(k)))
    min_heap = list(range(k,n))
    deleted = ['O' for _ in range(n)]
    deleted_stack = []
    heapq.heapify(max_heap)
    heapq.heapify(min_heap)
    for cmd in cmds:
        command = cmd.split()
        if len(command)>1:
            num = command[1]
            command = command[0]
            num = int(num)
            if command == 'D':
                for _ in range(num):
                    heapq.heappush(max_heap,-heapq.heappop(min_heap))
            else:
                for _ in range(num):
                    heapq.heappush(min_heap,-heapq.heappop(max_heap))
        else:
            command = command[0]
            if command == 'C':
                delete_num = heapq.heappop(min_heap)
                deleted_stack.append(delete_num)
                deleted[delete_num] = 'X'
                if len(min_heap) == 0:
                    heapq.heappush(min_heap,-heapq.heappop(max_heap))
            else:
                restore_num = deleted_stack.pop()
                deleted[restore_num] = 'O'
                if min_heap[0] > restore_num:
                    heapq.heappush(max_heap,-restore_num)
                else:
                    heapq.heappush(min_heap,restore_num)
    answer = ''.join(deleted)
    return answer

 

2021 카카오 채용연계형 인턴십에서 나왔던 3번 문제인 표 편집을 다시 풀어보았다.

 

그때는 효율성에서 통과 못했던 문제였던지라, 문제가 공개되자마자 푼 문제이다.

 

시험이 끝난 후 생각했던 풀이를 옮기는 작업으로 많이 걸리지 않았지만, 바보같이 python2로 제출을 해서 오래걸렸다.

 

첫 풀이는

https://www.acmicpc.net/problem/1655

 

1655번: 가운데를 말해요

첫째 줄에는 수빈이가 외치는 정수의 개수 N이 주어진다. N은 1보다 크거나 같고, 100,000보다 작거나 같은 자연수이다. 그 다음 N줄에 걸쳐서 수빈이가 외치는 정수가 차례대로 주어진다. 정수는 -1

www.acmicpc.net

백준의 가운데를 말해요 와

 

https://www.acmicpc.net/problem/21944

 

21944번: 문제 추천 시스템 Version 2

recommend, recommend2, recommend3 명령이 주어질 때마다 문제 번호를 한 줄씩 출력한다. 주어지는 recommend, recommend2, recommend3 명령어의 총 개수는 최소 1개 이상이다.

www.acmicpc.net

 문제 추천 시스템 version2에서 풀었던 풀이를 이용했다.

 

가운데를 말해요에서 처럼 가운데에 유지해야할 인수를 선택된 행으로 해주면 된다.

 

나는 그걸 min_heap의 최저값으로 유지를 해줬다.

 

max_heap과 min_heap 2개로 나뉜 상태에서 문제에서 주어진 k보다 크거나 같은 수는 min_heap에 그대로 넣어줘서, 저장을 해줬고, k보다 작은 수는 max_heap에 -를 곱해서 넣어줬다.

 

heap을 넣었다 뺐다하면, 시간이 오래걸리니 heapq의 heapify 메소드를 이용해 한번에 바꿔주었다.

 

이 상태에서 min_heap의 최저값이 우리가 선택한 인덱스가 유지해주게 되면 된다.

 

U 명령어와 같은경우엔 우리가 선택한 인덱스가 줄어들어야한다.

 

그래서 max_heap에서 인수를 꺼내서 -1을 곱한뒤 min_heap에 넣어준다.

 

그와 반대로 D 명령어는 우리가 선택한 인덱스가 늘어나야하므로,

 

min_heap에서 인수를 꺼내서 -1을 곱한뒤 max_heap에 넣어준다.

 

다음으로 지우는 명령어인 C가 주의해야한다.

 

먼저 C를 시행하는 방법은 우리가 봤던 index를 없애는 것이므로, min_heap에서 빼준뒤, 그 값을 스택에 넣어주고,

 

상태를 변화를 해준다.

 

이대로 그대로 하면 min_heap이 아무것도 안남는경우가 생길것이다. 즉 길이가 5인데 만약에 k가 4인 상태에서

 

C명령어를 수행하면 min_heap이 전부 비워지게 된다. 그러면 가리키는 인덱스가 없어지는 것이므로, 

 

max_heap에서 하나를 꺼내와서 min_heap에 넣어주는 작업을 해주면 된다.

 

 

Z 명령어는 간단하다.

 

우리가 삭제한 변수들을 저장한 스택에서 pop을 한뒤 그 인자가

 

현재 우리가 가리키는 인덱스 보다 작으면  max_heap에 넣어주고 크면 min_heap에 넣어주면 된다.

 

이렇게 한뒤 마지막으로 상태를 join을 한뒤 돌려주면 된다.

 

 

class Node():
    def __init__(self,data):
        self.prev = None
        self.next = None
        self.data = data
        

def solution(n,k,cmds):
    node_list = [Node(0)]
    deleted_stack = []
    deleted_state = ['O' for _ in range(n)]
    for num in range(1,n):
        prev_num = node_list[num-1]
        cur_num = Node(num)
        prev_num.next = cur_num
        cur_num.prev = prev_num
        node_list.append(cur_num)
    
    cur_node = node_list[k]

    for cmd in cmds:
        command = cmd.split()
        if len(command)>1:
            num = int(command[1])
            command = command[0]
            if command =='D':
                for _ in range(num):
                    cur_node = cur_node.next
            else:
                for _ in range(num):
                    cur_node = cur_node.prev
        else:
            command = command[0]
            if command == 'C':
                prev_num = cur_node.prev
                next_num = cur_node.next
                if next_num == None:
                    prev_num.next = None
                    deleted_stack.append(cur_node)
                    deleted_state[cur_node.data] = 'X'
                    cur_node = prev_num
                elif prev_num == None:
                    next_num.prev = None
                    deleted_stack.append(cur_node)
                    deleted_state[cur_node.data] = 'X'
                    cur_node = next_num
                else:
                    prev_num.next = next_num
                    next_num.prev = prev_num
                    deleted_stack.append(cur_node)
                    deleted_state[cur_node.data] = 'X'
                    cur_node = next_num
            else:
                restore_node = deleted_stack.pop()
                prev_num = restore_node.prev
                next_num = restore_node.next
                if prev_num != None:
                    prev_num.next = restore_node
                if next_num != None:
                    next_num.prev = restore_node
                deleted_state[restore_node.data] = 'O'
    answer = ''.join(deleted_state)
    return answer

 

 

두번째는 해설에도 나온 링크드리스트를 활용해서 풀면된다.

 

여기서 주의해야할 점은 prev나 next가 None일때 예외 처리르 어떻게 해주는지 이다. 그 외에는 일반적으로 하면 된다.

 

좀 더 개선한 코드는 

 

class Node():
    def __init__(self,data):
        self.prev = None
        self.next = None
        self.data = data
        

def solution(n,k,cmds):
    node_list = [Node(0)]
    deleted_stack = []
    deleted_state = ['O' for _ in range(n)]
    for num in range(1,n):
        prev_num = node_list[num-1]
        cur_num = Node(num)
        prev_num.next = cur_num
        cur_num.prev = prev_num
        node_list.append(cur_num)
    
    cur_node = node_list[k]

    for cmd in cmds:
        command = cmd.split()
        if len(command)>1:
            num = int(command[1])
            command = command[0]
            if command =='D':
                for _ in range(num):
                    cur_node = cur_node.next
            else:
                for _ in range(num):
                    cur_node = cur_node.prev
        else:
            command = command[0]
            if command == 'C':
                prev_num = cur_node.prev
                next_num = cur_node.next
                deleted_stack.append(cur_node)
                deleted_state[cur_node.data] = 'X'
                if next_num != None:
                    next_num.prev = prev_num
                if prev_num != None:
                    prev_num.next = next_num
                if next_num != None:
                    cur_node = next_num
                else:
                    cur_node = prev_num
            else:
                restore_node = deleted_stack.pop()
                prev_num = restore_node.prev
                next_num = restore_node.next
                if prev_num != None:
                    prev_num.next = restore_node
                if next_num != None:
                    next_num.prev = restore_node
                deleted_state[restore_node.data] = 'O'
    answer = ''.join(deleted_state)
    return answer

이 코드이다.

 

정확한 해설은 이보다 https://tech.kakao.com/2021/07/08/2021-%EC%B9%B4%EC%B9%B4%EC%98%A4-%EC%9D%B8%ED%84%B4%EC%8B%AD-for-tech-developers-%EC%BD%94%EB%94%A9-%ED%85%8C%EC%8A%A4%ED%8A%B8-%ED%95%B4%EC%84%A4/

 

2021 카카오 인턴십 for Tech developers 코딩 테스트 해설

2021년 카카오의 여름 인턴십의 첫 번째 관문인 코딩 테스트가 지난 2021년 5월 8일에 4시간에 걸쳐 진행되었습니다. 이번 인턴 코딩 테스트에서는 5문제가 출제되었습니다. 이전과 동일하게 쉬운

tech.kakao.com

 

여기가 잘 설명되어있다.

from collections import defaultdict

N = int(input())

arr = list(input())

cnt_dict = defaultdict(int)
result = 0
prev_idx = -1
for idx in range(N):
    alpha = arr[idx]
    cnt_dict[alpha] = idx
    if len(cnt_dict) > N:
        min_idx = 100001
        min_key = -1
        for key in cnt_dict:
            if min_idx > cnt_dict[key]:
                min_idx = cnt_dict[key]
                min_key = key
        prev_idx = min_idx
        del cnt_dict[min_key]
    
    result = max(result,idx-prev_idx)
print(result)

 

 

전형적인 두 포인터 문제이고, 그걸 다른 방식으로 푼 것이다. 

 

defalutdict를 통해, 각 알파벳의 마지막 위치를 저장시켜주고,

 

그 길이가 N을 넘어서게 되면, 마지막 위치가 가장 작은 알파벳을 삭제해주는 방식으로 해주고

 

prev_idx를 갱신해준다

 

위와 같은 방식을 통해 문제를 풀어주면 된다.

 

 

 

   . 

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 4358 생태학  (0) 2021.07.12
[BOJ/백준] 16947 서울 지하철 2호선  (0) 2021.07.12
[BOJ/백준] 16398 행성 연결  (0) 2021.06.29
[BOJ/백준] 15661 링크와 스타트  (0) 2021.06.29
[BOJ/백준] 10711 모래성  (0) 2021.06.29
import sys
import heapq

def input():
    return sys.stdin.readline().rstrip()


N = int(input())

graph = [list(map(int,input().split())) for _ in range(N)]

node_list = []

INF = float('inf')
distance_list = [INF for _ in range(N)]
visited = [False for _ in range(N)]
heapq.heappush(node_list,(0,0))
distance_list[0] = 0
result = 0
while node_list:
    cur_dis,cur_node = heapq.heappop(node_list)
    if visited[cur_node]:continue
    if cur_dis > distance_list[cur_node]:continue
    result += cur_dis
    visited[cur_node] = True
    for next_node in range(N):
        if next_node == cur_node:continue
        if visited[next_node]:continue
        if distance_list[next_node] > graph[cur_node][next_node]:
            distance_list[next_node] = graph[cur_node][next_node]
            heapq.heappush(node_list,(distance_list[next_node],next_node))

print(result)

 프림 알고리즘

 

 

import sys

def input():
    return sys.stdin.readline().rstrip()

def find_parents(X):
    if X == make_set[X]:
        return X
    make_set[X] = find_parents(make_set[X])
    return make_set[X]


def union(x,y):
    X = find_parents(x)
    Y = find_parents(y)

    if X !=Y:
        if ranks[X]< ranks[Y]:
            X,Y = Y,X
        make_set[Y] = X
        if ranks[X] == ranks[Y]:
            ranks[X] += 1
        return True
    return False



N = int(input())
edge_list = []
for x in range(N):
    temp = list(map(int,input().split()))
    for y in range(N):
        if x == y:continue
        edge_list.append((temp[y],x,y))


edge_list.sort(reverse=True)

cnt = 1
make_set = [i for i in range(N)]
ranks = [1 for _ in range(N)]
result = 0
while cnt <N:
    dis,node_a,node_b = edge_list.pop()
    if union(node_a,node_b):
        cnt += 1
        result += dis

print(result)

 

크루스칼 알고리즘

 

 

전형적인 MST 문제이다. MST 문제를 푸는 방식대로 풀면 된다.

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 16947 서울 지하철 2호선  (0) 2021.07.12
[BOJ/백준] 16472 고냥이  (0) 2021.06.29
[BOJ/백준] 15661 링크와 스타트  (0) 2021.06.29
[BOJ/백준] 10711 모래성  (0) 2021.06.29
[BOJ/백준] 2463 비용  (0) 2021.06.29
import sys
from itertools import combinations
def input():
    return sys.stdin.readline().rstrip()


N = int(input())


arr = [list(map(int,input().split())) for _ in range(N)]
row = [sum(i) for i in arr]
col = [sum(i) for i in zip(*arr)]
new_arr = [i+ j for i, j in zip(row, col)]
total_sum = sum(new_arr)//2
result = float('inf')
for num in range(1,N//2+1):
    for combi in combinations(new_arr,num):
        result = min(result,abs(total_sum-sum(combi)))
        if result == 0:
            break
    if result == 0:
        break
print(result)

이 문제는 과거에 푼 boj.kr/14889 이 있었기 때문에 빨리 풀 수 있었다. 

 

기본 푼 방식은 원래와 똑같다. 각 col,row의 섬을 미리 구해놓고, 그걸 통해서 문제를 푸는 방식이다.

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 16472 고냥이  (0) 2021.06.29
[BOJ/백준] 16398 행성 연결  (0) 2021.06.29
[BOJ/백준] 10711 모래성  (0) 2021.06.29
[BOJ/백준] 2463 비용  (0) 2021.06.29
[BOJ/백준] 1398 동전  (0) 2021.06.29
import sys
from collections import deque
def input():
    return sys.stdin.readline().rstrip()


N,M = map(int,input().split())


arr = [list(input()) for _ in range(N)]

sand_deque = deque()

dx = [-1,0,1,-1,1,-1,0,1]
dy = [-1,-1,-1,0,0,1,1,1]
for x in range(N):
    for y in range(M):
        if not arr[x][y].isdigit():
            sand_deque.append((x,y))
        else:
            arr[x][y] = int(arr[x][y])



time = 0

while True:
    remove_sand = deque()

    while sand_deque:
        x,y = sand_deque.popleft()
        for i in range(8):
            nx = x + dx[i]
            ny = y + dy[i]
            if not (0<=nx<N and 0<=ny<M):continue
            if arr[nx][ny] != '.':
                arr[nx][ny] -= 1
                if arr[nx][ny] == 0:
                    remove_sand.append((nx,ny))
                    arr[nx][ny] = '.'
    if remove_sand:
        sand_deque.extend(remove_sand)
        time += 1
    else:
        break

print(time)

 

 

처음에 dict을 이용해서 날먹할려다가 시간초과가 났던 문제이다.

 

문제를 푸는 방식은 다음과 같다. 모든 모래들은 하나의 큐에 넣은 뒤에, 8방향의 모래가 아닌곳의 개수를 -1씩 해준다.

 

그러면서 0이 되면, remove_sand라는 큐에 넣어준다.

 

그러면 이 모래성이 더이상 무너지지 않는 경우는 remove_sand가 비어있을때이다.

 

그래서 remove_sand가 비어있으면 break를 해주고, 있을때에는 sand_deque에 extend 해주고 그 때 time을 1 늘려준다.

 

 

import sys
from collections import deque
def input():
    return sys.stdin.readline().rstrip()


def bfs(queue):
    dx = [-1,0,1,-1,1,-1,0,1]
    dy = [-1,-1,-1,0,0,1,1,1]
    while queue:
        x,y,cnt = queue.popleft()

        for i in range(8):
            nx = x + dx[i]
            ny = y + dy[i]
            if not(0<=nx<N and 0<=ny<M):continue
            if arr[nx][ny]:
                arr[nx][ny] -= 1
                if not arr[nx][ny]:
                    queue.append((nx,ny,cnt+1))

    return cnt

N,M = map(int,input().split())


arr = [list(input()) for _ in range(N)]


sand_deque = deque()
for x in range(N):
    for y in range(M):
        if arr[x][y].isdigit():
            arr[x][y] = int(arr[x][y])
        else:
            arr[x][y] = 0
            sand_deque.append((x,y,0))



print(bfs(sand_deque))

 이 코드는 좀 더 깔끔한 방식으로 코드를 바꾼 방식이다.

 

원리 자체는 똑같다.    

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 16398 행성 연결  (0) 2021.06.29
[BOJ/백준] 15661 링크와 스타트  (0) 2021.06.29
[BOJ/백준] 2463 비용  (0) 2021.06.29
[BOJ/백준] 1398 동전  (0) 2021.06.29
[BOJ/백준] 1045 도로  (0) 2021.06.29
import sys

def input():
    return sys.stdin.readline().rstrip()
def find_parents(X):
    if X == make_set[X]:
        return X
    make_set[X] = find_parents(make_set[X])
    return make_set[X]


def union(x,y):
    X = find_parents(x)
    Y = find_parents(y)

    if X !=Y:
        if ranks[X]< ranks[Y]:
            X,Y = Y,X
        make_set[Y] = X
        groups[X].extend(groups[Y])
        if ranks[X] == ranks[Y]:
            ranks[X] += 1
        return True
    return False
N,M = map(int,input().split())


edge_list = []
total_pay = 0
for _ in range(M):
    x,y,pay = map(int,input().split())
    if x>y:
        x,y = y,x
    edge_list.append((x,y,pay))
    total_pay += pay
K = 10**9
past_pay = 0
answer = 0
make_set = [i for i in range(N+1)]
ranks = [1 for _ in range(N+1)]
groups = [[k] for k in range(N+1)]
edge_list.sort(key=lambda x : x[2])
while edge_list:
    x,y,pay = edge_list.pop()
    p_x,p_y = find_parents(x),find_parents(y)
    if p_x!= p_y:
        answer = answer + len(groups[p_x])*len(groups[p_y])*(total_pay-past_pay)
        answer = answer%K
        union(x,y)
    past_pay += pay
print(answer)

 

구사과님의 분리집합 추천 문제에 있어서, 풀었던 문제였는데 푸는데 어려웠던 문제였다. 

 

이 문제는 분리집합만을 이요해 푸는 문제인데, 푸는 아이디어는 다음과 같다.

 

모든 간선이 연결이 끊어져 있다고 생각을 하고, 간선의 비용이 가장 많은 비용 순으로 연결을 해주는 것이다.

 

 

문제에 주어진 Cost(u,v)는 u,v가 연결이 안될때까지 최소비용의 간선을 제거해주는 것이다.

 

즉 역으로 생각하면  간선의 비용이 가장 많은 비용 순으로 연결을 해줄때, 처음 u,v가 같은 집합이 되면,

 

그 전까지 연결된 비용을 제외한 나머지 비용들이, 이 u,v의 Cost가 될 수 있다.

 

 

즉 전체 간선의 연결비용을 구한 상태에서 비용이 높은 순으로 연결을 해주면서 연결이 되었을때 까지의 누적 비용을

 

뺀 값이 해당 cost(u,v)인 것을 알 수 있게 된다.

 

그러면 여기서 고민사항은 다음과 같다.

 

서로 단독집합일때에는 위와 같이 한다고 할때, 만약에 서로 다른 집합에 이미 속해있는 상황에서는 어떻게 해야할지, 

 

일 것이다.

 

이 문제는 문제에서 cost(u,v)에서 u<v 이므로, 우리는 한가지 정보를 더 저장을 시켜줘야한다.

 

그것은 부모에 해당 집합에 속해있는 노드를 저장시켜놓는것이다.

 

그래서 우리는 각 부모들에 노드를 저장시켜주고, 그 노드의 수를 서로 곱해주면 두 집합에서 나올 수 있는

 

u,v의 집합의 개수와 동일 하다.

 

정리하면

 

(total_pay - past_pay) *len(group[parent_u]) *len(group[parent_v]) 로 나타낼수 있다.

 

 

위의 코드를 통해 AC를 받을 수 있었다. 이 문제의 아이디어를 떠오르는데 힘들었고, 구현을 하는데 어려움이 많았다.

 

아직 분리집합에 대해 잘 모르고, 응용하는법에 대해 잘 몰랐다는 것을 알게 되었다.

 

 

 

 

import sys

def input():
    return sys.stdin.readline().rstrip()
def find_parents(X):
    if X == make_set[X]:
        return X
    make_set[X] = find_parents(make_set[X])
    return make_set[X]


def union(x,y):
    X = find_parents(x)
    Y = find_parents(y)

    if X !=Y:
        if ranks[X]< ranks[Y]:
            X,Y = Y,X
        make_set[Y] = X
        groups[X] +=groups[Y]
        if ranks[X] == ranks[Y]:
            ranks[X] += 1
        return True
    return False
N,M = map(int,input().split())


edge_list = []
total_pay = 0
for _ in range(M):
    x,y,pay = map(int,input().split())
    if x>y:
        x,y = y,x
    edge_list.append((x,y,pay))
    total_pay += pay
K = 10**9
past_pay = 0
answer = 0
make_set = [i for i in range(N+1)]
ranks = [1 for _ in range(N+1)]
groups = [ 1 for _ in range(N+1)]
edge_list.sort(key=lambda x : x[2])
while edge_list:
    x,y,pay = edge_list.pop()
    p_x,p_y = find_parents(x),find_parents(y)
    if p_x!= p_y:
        answer = answer + groups[p_x]*groups[p_y]*(total_pay-past_pay)
        answer = answer%K
        union(p_x,p_y)
    past_pay += pay
print(answer)

 

 

이 코드는 노드를 전부 저장하는게 아닌, 노드의 개수를 저장해주는 방식으로 바꾼 코드이다.

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 15661 링크와 스타트  (0) 2021.06.29
[BOJ/백준] 10711 모래성  (0) 2021.06.29
[BOJ/백준] 1398 동전  (0) 2021.06.29
[BOJ/백준] 1045 도로  (0) 2021.06.29
[BOJ/백준] 2250 트리의 높이와 너비  (0) 2021.06.22
import sys

def input():
    return sys.stdin.readline().rstrip()

T = int(input())
dp = [i for i in range(101)]

for i in range(100):
    for coin in [10,25]:
        if i +coin > 100:continue
        dp[i+coin] = min(dp[i+coin],dp[i]+1)
for _ in range(T):
    N = int(input())

    answer = 0
    while N>0:
        answer += dp[N%100]
        N//=100
    print(answer)



 

 

이 문제는 100까지의 최소 코인의 개수를 구해준 뒤에, 100단위로 나눠주면서 그 개수를 더해주면 된다.

 

왜냐하면

 

이 문제에서 쓰이는 동전은 [1,10,25]로 되어있고

 

각 100^k 만 곱해진 동전이기 때문이다.

 

 

'알고리즘 > 백준' 카테고리의 다른 글

[BOJ/백준] 10711 모래성  (0) 2021.06.29
[BOJ/백준] 2463 비용  (0) 2021.06.29
[BOJ/백준] 1045 도로  (0) 2021.06.29
[BOJ/백준] 2250 트리의 높이와 너비  (0) 2021.06.22
[BOJ/백준] 2213 트리의 독립집합  (0) 2021.06.22

+ Recent posts