본문 바로가기
computer science/알고리즘

[알고리즘] 최소비용 알고리즘(Kruskal's, Prim's)

by 박연호의 개발 블로그 2020. 9. 18.

최소신장트리(MST : Minimum Spanning Tree)

 

크루스칼 알고리즘을 알기 전에 먼저 최소신장트리에 대해 간략하게 공부하고 넘어 가겠습니다.

 

신장트리란 트리의 특수한 형태로 모든 정점을 포함하고, 정점간 서로 연결되면서 사이클이 존재하지 않는 그래프를 의미합니다. 

 

 

 

 

여기서 최소신장트리는 신장트리들 중에서 간선의 가중치가 최소를 만족하는 신장트리를 의미합니다.

아래의 사진을 보면 여러 정점들 사이에 간선이 존재하고, 이 간선에는 가중치값이 존재합니다. 아래의 왼쪽 그래프에는 다양한 신장트리가 나올 수 있습니다. 하지만 여러개의 신장트리 중에서 그 가중치의 합(간선에 배정된 값)이 최소가 되는 트리가 최소신장트리 입니다.

 

여기서 최소신장트리를 만드는 알고리즘이 이번시간에 배울 크루스칼 알고리즘과 프림 알고리즘 입니다.


Kruskal's Algorithm

 

크루스칼 알고리즘은 사이클을 만들지 않으면서 가중치가 가장 작은 간선을 하나씩 선택하는 방법입니다.

 

1. 최소간선을 선택한다(작은값 부터).

2. 간선을 선택했을 때 사이클이 생성되는지 확인한다.

3. 생성되지 않으면 간선을 선택하고 1번부터 반복, 생성되면 다른 간선 선택한다.

4. 간선의 개수가 노드개수 -1 이될때까지 반복한다.

 

#include <iostream>
#include <algorithm>
#include <vector>

using namespace std;

// 부모 노드를 가져옴
int getParent(int set[], int x)
{
    if (set[x] == x)
    {
        return x;
    }
    return set[x] = getParent(set, set[x]);
}

// 부모 노드를 병합
void unionParent(int set[], int a, int b)
{
    a = getParent(set, a);
    b = getParent(set, b);
    if (a < b)
    {
        set[b] = a;
    }
    else
    {
        set[a] = b;
    }
}

// 같은 부모를 가지는지 확인
int find(int set[], int a, int b)
{
    a = getParent(set, a);
    b = getParent(set, b);

    if (a == b)
    {
        return 1;
    }
    else
    {
        return 0;
    }
}

class Edge
{
public:
    int node[2];
    int distance;
    Edge(int a, int b, int distance)
    {
        this->node[0] = a;
        this->node[1] = b;
        this->distance = distance;
    }
    bool operator<(Edge &edge)
    {
        return this->distance < edge.distance;
    }
};

int main(void)
{
    int n = 7;
    int m = 11;

    vector<Edge> v;

    v.push_back(Edge(1, 7, 12));
    v.push_back(Edge(1, 4, 28));
    v.push_back(Edge(1, 2, 67));
    v.push_back(Edge(1, 5, 17));
    v.push_back(Edge(2, 4, 24));
    v.push_back(Edge(2, 5, 62));
    v.push_back(Edge(3, 5, 20));
    v.push_back(Edge(3, 6, 137));
    v.push_back(Edge(4, 7, 13));
    v.push_back(Edge(5, 6, 45));
    v.push_back(Edge(5, 7, 73));

    // 간선의 비용으로 오름차순 정렬
    sort(v.begin(), v.end());

    // 각 정점이 포함된 그래프가 어디인지 저장
    int set[n];
    for (int i = 0; i < n; i++)
    {
        set[i] = i;
    }

    int sum = 0;
    for (int i = 0; i < v.size(); i++)
    {
        // 동일한 부모를 가르키지 않는 경우, 즉 사이클이 발생하지 않을 때만 선택
        if (!find(set, v[i].node[0] - 1, v[i].node[1] - 1))
        {
            sum += v[i].distance;
            unionParent(set, v[i].node[0] - 1, v[i].node[1] - 1);
        }
    }

    cout << sum << endl;
    return 0;
}

Prim's Algorithm

프림 알고리즘은 사이클을 만들지 않으면서 노드와 연결된 간선중 가중치가 가장 작은 간선을 하나씩 선택하는 방법입니다. 크루스칼 알고리즘은 노드에 연결된 간선에 상관없이 간선을 선택하지만, 프림 알고리즘은 노드에 연결된 간선중에 가중치가 가장 적은 간선을 선택합니다.

 

1. 최소간선을 선택한다(작은값 부터).

2. 간선을 선택했을 때 사이클이 생성되는지 확인한다.

3. 생성되지 않으면 간선을 선택하고 1번부터 반복, 생성되면 다른 간선 선택한다.

4. 간선의 개수가 노드개수 -1 이될때까지 반복한다.

 

prim알고리즘에서는 노드집합의 연결된 간선중에서 가중치가 가장 작은 간선을 선택해야 하기 때문에 우선순위큐를 사용합니다.

#include <iostream>
#include <queue>
#include <vector>
using namespace std;

struct Edge
{
    int start;
    int end;
    int cost;
    Edge() : start(0), end(0), cost(0)
    {
    }
    Edge(int start, int end, int cost) : start(start), end(end), cost(cost)
    {
    }
    bool operator<(const Edge &edge) const
    {
        return cost > edge.cost;
    }
};

vector<pair<int, int>> a[1001];
bool c[1001];

int main()
{
    int n, m;
    cin >> n >> m;
    for (int i = 0; i < m; i++)
    {
        int start, end, cost;
        cin >> start >> end >> cost;
        a[start].push_back(make_pair(end, cost));
        a[end].push_back(make_pair(start, cost));
    }
    c[1] = true;
    priority_queue<Edge> q;
    // 시작 노드를 1로 하고 q 에 1 과 연결된 노드들을 넣어줌
    for (int i = 0; i < a[1].size(); i++)
    {
        q.push(Edge(1, a[1][i].first, a[1][i].second));
    }
    int ans = 0;
    for (int i = 0; i < n - 1; i++)
    {
        Edge e;
        // 우선순위 큐이기 때문에 top 의 값은 가장 작은 값
        while (!q.empty())
        {
            e = q.top();
            q.pop();
            if (c[e.end] == false)
            {
                break;
            }
        }
        c[e.end] = true;
        ans += e.cost;
        int x = e.end;
        for (int i = 0; i < a[x].size(); i++)
        {
            q.push(Edge(x, a[x][i].first, a[x][i].second));
        }
    }
    cout << ans << "\n";
    return 0;
}


// 입력
6
9
1 2 5
1 3 4
2 3 2
2 4 7
3 4 6
3 5 11
4 5 3
4 6 8
5 6 8