알고리즘 문제 풀이/Python

python 백준 2887 행성 터널

맛대 2023. 6. 2. 11:29

MST(최소 신장 트리)문제

전체적인 로직

  • x 좌표, y좌표, z좌표를 따로 저장(행성 번호와 함께 저장)

  • KRUSKAL 알고리즘 이용

  • x,y,z 좌표를 정렬한 뒤, 앞 뒤 행성간의 거리 측정해서 저장(먼 거리의 행성은 끼리 계산할 필요 x)

  • 측정된 거리 순서로 정렬

  • 정렬 순서로 연결

코드 설명

union 함수

  • a행성과 b행성을 연결하는 함수
  • depth : 행성 연결이 지름이 길게 연결되지 않고 균등하게 붙어지게 하기 위한 최적화

find_root 함수

  • n번 행성(node)가 어디로 연결되어 있는지 확인하는 함수
  • group[n] = find_root(T) 은 find_root를 자주 사용할 경우 함수 호출을 줄이기 위해 최적화

find_min_length 함수

  • 인자로 받은 ls중 제일 작은 것의 index를 반환하는 함수
  • ls[0]은 x좌표 행성간의 거리, ls[1]은 y좌표 행성간의 거리, ls[2]는 z좌표 행성간의 거리

gap_ls 함수

  • 인자로 받은 list(좌표 기록된 list)를 이용하여 앞,뒤 행성간의 거리를 list로 만들어 반환하는 함수
  • 각 행성의 index를 기록하기 위해 tuple 형태로 앞,뒤 행성도 같이 기록

solution의 while

  • 현재 x,y,z 좌표를 봤을때 가장 짧은 두 행성을 찾고 연결 하는 반복문
  • 연결 후 ls의 값(x,y,z 앞 뒤 행성간의 거리)들을 확인 하고, 만약 해당 값의 행성들이 연결된 경우 다음 좌표의 행성을 ls에 갱신
  • 이 때 갱신은 KRUSKAL 알고리즘으로 다음으로 짧은 두 행성간의 거리로 갱신이 됨
import sys
input = sys.stdin.readline

def union(a:int,b:int)->None:
    A = find_root(a)
    B = find_root(b)

    if depth[A] > depth[B]:
        group[B] = A
    elif depth[A] < depth[B]:
        group[A] = B
    else:
        group[B] = A
        depth[A] += 1

def find_root(n:int)->int:
    T = group[n]
    if T == n:
        return n
    group[n] = find_root(T)
    return group[n]

def find_min_length(ls:list[int,int,int])->int:
    a,b,c = ls[0][0],ls[1][0],ls[2][0]
    if a <= b and a <= c:
        return 0
    elif b <= a and b <= c:
        return 1
    else:
        return 2

def gap_ls(ls:list[int])->list[int]:
    ls.sort()
    l = len(ls)
    value = []
    for i in range(l-1):
        value.append((ls[i+1][0] - ls[i][0],ls[i+1][1],ls[i][1]))
    value.sort()
    return value

def solution(N:int,x_point:list[int],y_point:list[int],z_point:list[int])->int:
    if N == 1:
        return 0
    n = N - 1
    value = 0
    x_gap = gap_ls(x_point)
    y_gap = gap_ls(y_point)
    z_gap = gap_ls(z_point)

    xyz = [x_gap,y_gap,z_gap]
    index = [0,0,0]
    ls = [xyz[i][0] for i in range(3)]

    while True:
        T = find_min_length(ls)
        union(ls[T][1],ls[T][2])
        value += ls[T][0]

        for i in range(3):
            while True:
                a,b = xyz[i][index[i]][1],xyz[i][index[i]][2]
                if find_root(a) == find_root(b):
                    index[i] += 1
                    if index[i] == n:
                        return value
                else:
                    break
            ls[i] = xyz[i][index[i]]

N = int(input())
x_point,y_point,z_point = list(),list(),list()
for i in range(N):
    x,y,z = map(int,input().split())
    x_point.append((x,i))
    y_point.append((y,i))
    z_point.append((z,i))
group = [num for num in range(N)]
depth = [0]*(N)
ans = solution(N,x_point,y_point,z_point)
print(ans)

'알고리즘 문제 풀이 > Python' 카테고리의 다른 글

python 백준 1516 게임 개발  (0) 2023.07.11
python 백준 1461 도서관  (0) 2023.07.05
python 백준 1939 중량제한  (0) 2023.04.14
python 백준 9655 돌 게임  (0) 2023.03.30
python 백준 10986 나머지 합  (0) 2023.03.21