NullNull

[백준 P1197] 최소 스패닝 트리 Java 본문

알고리즘

[백준 P1197] 최소 스패닝 트리 Java

KYBee 2022. 10. 3. 00:20

P1197 최소 스패닝 트리

 

1197번: 최소 스패닝 트리

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이

www.acmicpc.net

문제

그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.

최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.

입력값

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.

그래프의 정점은 1번부터 V번까지 번호가 매겨져 있고, 임의의 두 정점 사이에 경로가 있다. 최소 스패닝 트리의 가중치가 -2,147,483,648보다 크거나 같고, 2,147,483,647보다 작거나 같은 데이터만 입력으로 주어진다.

3 3
1 2 1
2 3 2
1 3 3

출력값

첫째 줄에 최소 스패닝 트리의 가중치를 출력한다.

3

알고리즘

MST를 구현하는 문제이다. MST를 구하기 위해 사용하는 알고리즘 중 유명한 것은 크루스칼과 프림이다.

크루스칼 알고리즘은 간선을 기준으로 최소 신장 트리를 구한다. 크루스칼은 Union Find 알고리즘을 통해 구현된다. 다음과 같은 순서로 진행된다.

1. 간선의 정보 (간선의 길이, 왼쪽 노드, 오른쪽 노드) 를 간선의 길이 기준으로 정렬한다.
2. 앞에서부터 차례대로 간선 하나를 선택한다.
	2-1. 두 노드가 이미 연결되어 있는지 확인한다.
		2-1-1. 연결되었다면 2번으로 간다.
		2-1-2. 연결되지 않았다면 UNION 하여 연결시키고 가중치를 추가해준다.
	2-2. 몇 개의 간선이 선택됐는지 확인한다.
		2-2-1. N-1개보다 작다면 2번으로 돌아간다.
		2-2-2. N-1개와 같다면 종료한다.

반면에 프림 알고리즘은 노드를 기준으로 최소 신장 트리를 구한다. 인접리스트와 우선순위 큐를 활용하여 구할 수 있다. 다음과 같은 순서로 진행된다.

1. 인접 리스트로 그래프를 저장한다. 각 노드와 그 노드와 인접한 노드들을 저장한다.
2. 하나의 노드를 임의로 선택한다. 간선의 가중치가 짧은 순서대로 힙을 구성한다.
3. 현재 우선순위 큐에서 루트노드를 poll 한다.
	3-1. 해당 노드를 이미 MST를 만들 때 사용했는지 확인한다.
		3-1-1. 방문했다면 다시 3으로 이동한다.
		3-1-2. 방문하지 않았다면 3-2로 이동한다.
	3-2. 해당 노드는 MST를 만들 때 사용할 것이므로 방문처리를 한다.
	3-3. 지금까지 선택된 노드의 개수가 N인지 확인한다.
		3-3-1. N이면 종료한다.
		3-3-2. N보다 작으면 그 노드와 인접한 모든 노드를 우선순위 큐어 넣어준다.

아래 코드는 두 가지 방식을 모두 구현했다.

코드

import java.io.*;
import java.util.*;

public class Main {

    static int V, E;
    static int K, P;

    static PriorityQueue<Integer[]> graph = new PriorityQueue<>((a, b) -> {
        return a[2] - b[2];
    });

    static ArrayList<Integer[]>[] list;

    static int[] parent;
    static boolean[] visited;

    static int primResult;
    static int kruskalResult;

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());

        V = Integer.parseInt(st.nextToken());
        E = Integer.parseInt(st.nextToken());

        parent = new int[V + 1];
        visited = new boolean[V + 1];
        list = new ArrayList[V + 1];

        for (int i = 1; i <= V ; i++) {
            parent[i] = i;
            list[i] = new ArrayList<>();
        }

        for (int i = 0; i < E; i++) {
            st = new StringTokenizer(br.readLine());
            int from = Integer.parseInt(st.nextToken());
            int to = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());

            graph.add(new Integer[] {from, to, weight});
            list[from].add(new Integer[] {to, weight});
            list[to].add(new Integer[] {from, weight});
        }

        Prim(1);
        System.out.println(primResult);
        //Kruskal();
        //System.out.println(kruskalResult);

    }

    public static void Prim(int n) {
        PriorityQueue<Integer[]> pq = new PriorityQueue<>((a, b) -> {
            return a[1] - b[1];
        });
        pq.add(new Integer[] {n, 0});

        while (!pq.isEmpty()) {

            Integer[] current = pq.poll();

            if (visited[current[0]]) continue;

            visited[current[0]] = true;
            P++;
            primResult += current[1];

            if (P == V) break;
            else {
                for (Integer[] next : list[current[0]]) {
                    pq.add(new Integer[] {next[0], next[1]});
                }
            }
        }
    }

    public static void Kruskal() {
        while (!graph.isEmpty()) {
            Integer[] current = graph.poll();

            if (find(current[0]) != find(current[1])) {
                union(current[0], current[1]);
                K++;
                kruskalResult += current[2];
            }

            if (K == V - 1) break;
        }
    }

    public static void union(int a, int b) {
        int parentA = find(a);
        int parentB = find(b);

        if (parentA < parentB) {
            parent[parentB] = parentA;
        } else {
            parent[parentA] = parentB;
        }
    }

    public static int find(int a) {
        if (a == parent[a]) return a;
        return parent[a] = find(parent[a]);
    }
}
Comments