MST(최소 신장 트리)문제
전체적인 로직
x 좌표, y좌표, z좌표를 따로 저장(행성 번호와 함께 저장)
KRUSKAL 알고리즘 이용
x,y,z 좌표를 정렬한 뒤, 앞 뒤 행성간의 거리 측정해서 저장(먼 거리의 행성은 끼리 계산할 필요 x)
측정된 거리 순서로 정렬
정렬 순서로 연결
코드 설명
union 함수
- a행성과 b행성을 연결하는 함수
- depth : 행성 연결이 지름이 길게 연결되지 않고 균등하게 붙어지게 하기 위한 최적화
find_root 함수
- n번 행성(node)가 어디로 연결되어 있는지 확인하는 함수
group[n] = find_root(T)
은 find_root를 자주 사용할 경우 함수 호출을 줄이기 위해 최적화
find_min_length 함수
- 인자로 받은 ls중 제일 작은 것의 index를 반환하는 함수
ls[0]
은 x좌표 행성간의 거리,ls[1]
은 y좌표 행성간의 거리,ls[2]
는 z좌표 행성간의 거리
gap_ls 함수
- 인자로 받은 list(좌표 기록된 list)를 이용하여 앞,뒤 행성간의 거리를 list로 만들어 반환하는 함수
- 각 행성의 index를 기록하기 위해 tuple 형태로 앞,뒤 행성도 같이 기록
solution의 while
- 현재 x,y,z 좌표를 봤을때 가장 짧은 두 행성을 찾고 연결 하는 반복문
- 연결 후 ls의 값(x,y,z 앞 뒤 행성간의 거리)들을 확인 하고, 만약 해당 값의 행성들이 연결된 경우 다음 좌표의 행성을 ls에 갱신
- 이 때 갱신은 KRUSKAL 알고리즘으로 다음으로 짧은 두 행성간의 거리로 갱신이 됨
import sys
input = sys.stdin.readline
def union(a:int,b:int)->None:
A = find_root(a)
B = find_root(b)
if depth[A] > depth[B]:
group[B] = A
elif depth[A] < depth[B]:
group[A] = B
else:
group[B] = A
depth[A] += 1
def find_root(n:int)->int:
T = group[n]
if T == n:
return n
group[n] = find_root(T)
return group[n]
def find_min_length(ls:list[int,int,int])->int:
a,b,c = ls[0][0],ls[1][0],ls[2][0]
if a <= b and a <= c:
return 0
elif b <= a and b <= c:
return 1
else:
return 2
def gap_ls(ls:list[int])->list[int]:
ls.sort()
l = len(ls)
value = []
for i in range(l-1):
value.append((ls[i+1][0] - ls[i][0],ls[i+1][1],ls[i][1]))
value.sort()
return value
def solution(N:int,x_point:list[int],y_point:list[int],z_point:list[int])->int:
if N == 1:
return 0
n = N - 1
value = 0
x_gap = gap_ls(x_point)
y_gap = gap_ls(y_point)
z_gap = gap_ls(z_point)
xyz = [x_gap,y_gap,z_gap]
index = [0,0,0]
ls = [xyz[i][0] for i in range(3)]
while True:
T = find_min_length(ls)
union(ls[T][1],ls[T][2])
value += ls[T][0]
for i in range(3):
while True:
a,b = xyz[i][index[i]][1],xyz[i][index[i]][2]
if find_root(a) == find_root(b):
index[i] += 1
if index[i] == n:
return value
else:
break
ls[i] = xyz[i][index[i]]
N = int(input())
x_point,y_point,z_point = list(),list(),list()
for i in range(N):
x,y,z = map(int,input().split())
x_point.append((x,i))
y_point.append((y,i))
z_point.append((z,i))
group = [num for num in range(N)]
depth = [0]*(N)
ans = solution(N,x_point,y_point,z_point)
print(ans)
'알고리즘 문제 풀이 > Python' 카테고리의 다른 글
python 백준 1516 게임 개발 (0) | 2023.07.11 |
---|---|
python 백준 1461 도서관 (0) | 2023.07.05 |
python 백준 1939 중량제한 (0) | 2023.04.14 |
python 백준 9655 돌 게임 (0) | 2023.03.30 |
python 백준 10986 나머지 합 (0) | 2023.03.21 |