NullNull

[백준 P2887] 행성 터널 Java 본문

알고리즘

[백준 P2887] 행성 터널 Java

KYBee 2022. 10. 3. 15:39

P2887 행성 터널

 

2887번: 행성 터널

첫째 줄에 행성의 개수 N이 주어진다. (1 ≤ N ≤ 100,000) 다음 N개 줄에는 각 행성의 x, y, z좌표가 주어진다. 좌표는 -109보다 크거나 같고, 109보다 작거나 같은 정수이다. 한 위치에 행성이 두 개 이

www.acmicpc.net

문제

때는 2040년, 이민혁은 우주에 자신만의 왕국을 만들었다. 왕국은 N개의 행성으로 이루어져 있다. 민혁이는 이 행성을 효율적으로 지배하기 위해서 행성을 연결하는 터널을 만들려고 한다.

행성은 3차원 좌표위의 한 점으로 생각하면 된다. 두 행성 A(xA, yA, zA)와 B(xB, yB, zB)를 터널로 연결할 때 드는 비용은 min(|xA-xB|, |yA-yB|, |zA-zB|)이다.

민혁이는 터널을 총 N-1개 건설해서 모든 행성이 서로 연결되게 하려고 한다. 이때, 모든 행성을 터널로 연결하는데 필요한 최소 비용을 구하는 프로그램을 작성하시오.

입력값

첫째 줄에 행성의 개수 N이 주어진다. (1 ≤ N ≤ 100,000) 다음 N개 줄에는 각 행성의 x, y, z좌표가 주어진다. 좌표는 -1,000,000,000보다 크거나 같고, 1,000,000,000보다 작거나 같은 정수이다. 한 위치에 행성이 두 개 이상 있는 경우는 없다.

5
11 -15 -15
14 -5 -15
-1 -1 -5
10 -4 -1
19 -4 19

출력값

첫째 줄에 모든 행성을 터널로 연결하는데 필요한 최소 비용을 출력한다.

4

알고리즘

모든 행성을 터널로 연결하는데 필요한 최소 비용을 구하는 문제이다.

모든 행성을 연결한다는 점과 최소 비용을 구해야 한다는 점에서 최소 신장 트리를 사용하는 문제구나 라는 것을 알 수 있다. 최소 신장 트리는 2개의 알고리즘을 대표적으로 사용하고 각 알고리즘의 특징은 다음과 같다. 크루스칼과 프림 알고리즘의 구현은 다음을 참고하자.

 

[백준 P1197] 최소 스패닝 트리 Java

P1197 최소 스패닝 트리 1197번: 최소 스패닝 트리 첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수

nullnull.tistory.com

  • 크루스칼 : union-find 알고리즘 사용, 간선을 하나씩 탐색하며 MST 구성 (간선 기준)
  • 프림 : 우선순위 큐 자료구조 사용, 정점을 하나씩 탐색하며 MST 구성 (정점 기준)

이 문제는 크루스칼 알고리즘을 사용해야 한다. 사실 프림 알고리즘도 적절한 가지치기와 초기화를 진행한다면 가능할 수도 있다. 하지만 N의 크기를 생각했을 때 크루스칼 알고리즘을 사용하여 간선 기준으로 MST를 찾는 것이 더 효과적일 것이라 판단했다.

문제와 입력값을 통해 알 수 있는 정보는 지금까지 총 몇 개의 정점이 있고, 각 정점이 위치한 좌표이다. 두 정점을 연결하는 간선이 존재한다고 가정하면 그 값은 두 행성의 X, Y, Z 좌표값 차이 중 가장 작은 값이 된다.

 

따라서 우선 모든 행성들에 대해서 X, Y, Z 좌표 값들의 차이를 구해둔 이후에 값이 작은 간선들을 하나씩 뽑아서, 이 둘이 연결된 경우에는 다음 간선으로 넘어가고, 연결되지 않았다면 서로 연결해주는 방법으로 문제를 구현했다.

 

이 문제의 핵심은 모든 행성들에 대해 X, Y, Z 좌표 값들의 차이를 구하는 것이다. 이를 위해 3개의 2차원 배열을 선언했다. 각각 몇 번째 행성인지와 그 행성의 X, Y, Z의 좌표를 저장한다. 완성한 이후에는 각 배열을 좌표 오름차순으로 정렬한다.

static long[][] X;
static long[][] Y;
static long[][] Z;

for (int i = 0; i < N; i++) {
    st = new StringTokenizer(br.readLine());
    X[i][0] = i; X[i][1] = Long.parseLong(st.nextToken());
    Y[i][0] = i; Y[i][1] = Long.parseLong(st.nextToken());
    Z[i][0] = i; Z[i][1] = Long.parseLong(st.nextToken());
}

//X, Y, Z
Arrays.sort(X, new Comparator<long[]>() {
    @Override
    public int compare(long[] o1, long[] o2) {
        return Long.compare(o1[1], o2[1]);
    }
});

Arrays.sort(Y, new Comparator<long[]>() {
    @Override
    public int compare(long[] o1, long[] o2) {
        return Long.compare(o1[1], o2[1]);
    }
});

Arrays.sort(Z, new Comparator<long[]>() {
    @Override
    public int compare(long[] o1, long[] o2) {
        return Long.compare(o1[1], o2[1]);
    }
});

MST는 모든 정점이 서로 연결되지만 정점을 연결하는 간선이 항상 최소가 되야 한다는 점을 생각해보자. 이 특징을 기억하고 아래의 예시를 살펴보자.

1: [1, 3, 5]
2: [4, 4, 10]
3: [2, 1, 1] 

세 행성의 정보를 위의 X, Y, Z 배열에 저장하면 다음과 같을 것이다.

X = [[1, 1], [3, 2], [2, 4]]
Y = [[3, 1], [1, 3], [2, 4]]
Z = [[3, 1], [1, 5], [2, 10]]

여기서 X를 살펴보겠다. X 배열에 있는 정보들을 통해 만들 수 있는 간선의 수는 다음과 같다.

  • 1과 2를 연결한 간선 : 2
  • 1과 3을 연결한 간선 : 3
  • 2와 3을 연결한 간선 : 2

X만 사용하여 MST를 구성하려면 정렬된 X배열의 i, i + 1 번째 원소를 차례대로 연결해야 할 것이다. 위의 예시에서는 1, 2를 연결한 간선 2와 2, 3을 연결한 간선 2를 사용할 경우이다. 물론 1, 2를 연결하고 1, 3을 연결할 수도 있지만 이 경우에는 MST를 만족하지 않는다. 모든 X, Y, Z배열을 좌표 순으로 정렬하는 이유가 바로 이 것이다. MST가 될 수 있는 간선의 후보군을 뽑아내기 위해서이다.

 

그렇게 정렬한 X, Y, Z 배열을 N - 1번 순회하면서 i 와 i + 1의 차이와 두 정점 정보를 PriorityQueue에 넣어준다. 우선순위 큐에는 MST가 될 수 있는 간선들의 후보군이 저장되며 두 정점과 사이의 거리 정보가 포함된다.

 

우선순위 큐에서 하나씩 간선의 정보를 poll한 뒤, 두 정점이 이미 이전에 다른 간선에 의해 연결되었는지 판단한다. 만약 연결되지 않았다면 현재 간선이 그 두 정점을 연결하기 위한 최소 값을 가지는 간선이므로 둘을 연결하고 연결된 정점의 수를 하나 늘려준다. 만약 연결한 간선이 N - 1개라면 MST가 구성되었다고 판단하고 코드 실행을 종료한다.

int node = 0;
PriorityQueue<long[]> pq = new PriorityQueue<>((a, b) -> {
    return Long.compare(a[2], b[2]);
});

for (int i = 0; i < N - 1; i++) {
    // X Y Z 3개 넣기
    pq.add(new long[] {X[i][0], X[i + 1][0], X[i + 1][1] - X[i][1]});
    pq.add(new long[] {Y[i][0], Y[i + 1][0], Y[i + 1][1] - Y[i][1]});
    pq.add(new long[] {Z[i][0], Z[i + 1][0], Z[i + 1][1] - Z[i][1]});
}

while (node < N - 1) {
    // 가장 작은거 뽑고 MST
    long[] current = pq.poll();

    int a = (int) current[0];
    int b = (int) current[1];
    long weight = current[2];

    if (find(a) != find(b)) {
        node++;
        union(a, b);
        answer += weight;
    }
}

전체 코드는 다음과 같다.

코드

import java.util.*;
import java.io.*;

public class Main {

    static int N;
    static long[][] X;
    static long[][] Y;
    static long[][] Z;

    static int[] parent;
    static long answer;

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;

        N = Integer.parseInt(br.readLine());
        X = new long[N][2];
        Y = new long[N][2];
        Z = new long[N][2];

        parent = new int[N];
        for (int i = 1; i < N; i++) parent[i] = i;

        for (int i = 0; i < N; i++) {
            st = new StringTokenizer(br.readLine());
            X[i][0] = i; X[i][1] = Long.parseLong(st.nextToken());
            Y[i][0] = i; Y[i][1] = Long.parseLong(st.nextToken());
            Z[i][0] = i; Z[i][1] = Long.parseLong(st.nextToken());
        }

        //X, Y, Z
        Arrays.sort(X, new Comparator<long[]>() {
            @Override
            public int compare(long[] o1, long[] o2) {
                return Long.compare(o1[1], o2[1]);
            }
        });

        Arrays.sort(Y, new Comparator<long[]>() {
            @Override
            public int compare(long[] o1, long[] o2) {
                return Long.compare(o1[1], o2[1]);
            }
        });

        Arrays.sort(Z, new Comparator<long[]>() {
            @Override
            public int compare(long[] o1, long[] o2) {
                return Long.compare(o1[1], o2[1]);
            }
        });

        int node = 0;
        PriorityQueue<long[]> pq = new PriorityQueue<>((a, b) -> {
            return Long.compare(a[2], b[2]);
        });

        for (int i = 0; i < N - 1; i++) {
            // X Y Z 3개 넣기
            pq.add(new long[] {X[i][0], X[i + 1][0], X[i + 1][1] - X[i][1]});
            pq.add(new long[] {Y[i][0], Y[i + 1][0], Y[i + 1][1] - Y[i][1]});
            pq.add(new long[] {Z[i][0], Z[i + 1][0], Z[i + 1][1] - Z[i][1]});
        }

        while (node < N - 1) {
            // 가장 작은거 뽑고 MST
            long[] current = pq.poll();

            int a = (int) current[0];
            int b = (int) current[1];
            long weight = current[2];

            if (find(a) != find(b)) {
                node++;
                union(a, b);
                answer += weight;
            }
        }

        System.out.println(answer);
    }

    public static void union(int a, int b) {
        int parentA = find(a);
        int parentB = find(b);

        if (parentA < parentB) {
            parent[parentB] = parentA;
        } else if (parentA > parentB) {
            parent[parentA] = parentB;
        }
    }

    public static int find(int target) {
        if (target == parent[target]) return target;
        else return parent[target] = find(parent[target]);
    }
}
Comments