본문 바로가기
알고리즘, 문제해결/알고리즘, 자료구조

세그먼트 트리

by 카펀 2021. 9. 25.

다음과 같은 문제를 생각해 봅시다.

 

배열 arr에는 수많은 int형 정수가 들어 있다. arr의 크기 n은 최대 10만이며, 각 정수는 1000 이하의 크기를 가진다.

배열이므로, 당연히 arr의 모든 원소는 인덱스가 정해져 있다. 인덱스는 0부터 시작한다.

총 m줄에 걸쳐 구간이 주어질 때, 주어진 구간합을 출력하는 코드를 작성하시오. (n <= 10만).

 

구간이 주어진다는 말은, 예를 들어 구간이 3, 6으로 주어진다면, 인덱스 3부터 6까지 모든 원소의 합을 출력하는 것을 의미합니다.

위 문제를 단순하게 해결해 봅시다.

 

vector<int> arr = {1, 2, 9, 8, 4, 5, 3, 7};
vector<pair<int, int> > ranges = {{3, 4}, {1, 5}, {2, 6}, {0, 7}};

for (auto& range : ranges) {
	long long answer = 0;
    	for (int i = range.first; i <= range.second; i++) {
    	answer += arr[i];
    }
    cout << answer << '\n';
}

 

배열 arr는 총 8개의 원소를 가지고 있으며, 인덱스 번호를 보기 쉽게 표현해 보면 아래와 같습니다.

 

인덱스 번호 0 1 2 3 4 5 6 7
원소 1 2 9 8 4 5 3 7

 

주어진 구간은 [3, 4], [1, 5], [2, 6], [0, 7] 입니다. 

즉 위 문제에서는 n = 8, m = 4가 됩니다.

 

위 코드대로 문제를 해결한다면 다음과 같은 과정을 수행합니다.

  • 주어진 구간에 대해 배열에 각각 접근하고, 합연산을 수행합니다. 최악의 경우 배열 내의 원소를 n번 접근해야 합니다.
  • 각 구간에 대해 합연산을 구한 후, 출력합니다. 최악의 경우 m개의 구간합을 구해야 합니다.

따라서 위 코드의 시간 복잡도는 O(nm)이 되며, n과 모두 10만 이하의 자연수이므로, 시간복잡도를 줄일 방법을 고려해야 합니다.

실제로 위와 같은 방법으로 구간 합 구하기 (문제)를 풀면, 시간 초과에 걸립니다.

 

이런 상황에서 쓸 수 있는 효과적인 알고리즘이 있습니다.

세그먼트 트리라고 하는데, 간단히 말하면 특정 구간의 합을 미리 구해둔 후, 요청이 있을 때 이미 구한 합을 활용하여 답을 구하는 것입니다.

 

다음과 같은 이진 트리를 생성할 수 있습니다.

구간 합을 구한 이진 트리

숫자 밑의 아랫첨자는 무시해 주세요.

위 트리가 의미하는 것은 무엇일까요?

 

좌: 각 노드가 포함하고 있는 구간, 우: 세그먼트 트리의 노드 번호

세그먼트 트리는 한 번에 세 가지 그림을 놓고 보면 이해가 쉽습니다.

맨 처음 본 트리는 정해진 각 구간의 합입니다.

두 번째 트리는 각 구간을 나타내고 있는데, 루트는 모든 구간을 의미하며, 한 노드의 자식들은 부모 노드의 구간을 반으로 나눈 구간을 포함합니다. 부모 노드의 구간이 오직 하나인 경우, 자식 노드는 존재하지 않습니다.

세 번째 트리는 세그먼트 트리의 노드 번호입니다. 0이 아닌 1부터 시작함을 주목해 주세요. 부모 노드 node의 자식 노드는 node*2, node*2 + 1로 쉽게 나타내기 위해 루트 노드는 1번으로 정합니다.

 

사용하지 않는 0번을 포함하여, 세그먼트 트리의 크기는 "기존 배열의 크기 n보다 큰 최소의 제곱수 * 2" 가 됩니다.

무슨 뜻인지 이해가 어렵다면, 이를 구하는 코드를 참고해 주세요.

 

int closest_square (int n) {
    //n == array's size
    
    int i = 1;
    while(1) {
        if (pow(i, 2) >= n) break;
        i++;
    }
    return pow(i,2) * 2;
}

 

이번 경우에는 n = 8이므로, n보다 큰 가장 작은 제곱수는 9가 될 것이고, 따라서 세그먼트 트리의 크기는 18이 됩니다.

실제로는 데이터의 개수 n * 4 만큼 미리 세그먼트 트리 영역을 할당하면 됩니다.

 

그렇다면 세그먼트 트리를 어떻게 만들 수 있을까요?

우리가 주목할 점은, "자식 노드 두 개는 부모 노드의 구간을 반으로 나눈 구간을 가진다"는 점입니다.

따라서 구간이 단 하나만 될 때까지 구간을 반으로 나누어 가며 재귀적으로 접근할 수 있습니다.

 

vector<int> array는 문제에서 주어진 배열, vector<int> tree는 만들어 둔 빈 세그먼트 트리 배열이라고 합시다.

 

int segment_tree (vector<int> &array, vector<int> &tree, int start, int end, int node) {
    if (start == end) return tree[node] = array[start];
    int mid = (start + end) / 2;
    
    return tree[node] = segment_tree(array, tree, start, mid, node*2) + segment_tree(array, tree, mid+1, end, node*2 + 1);
}

 

start == end인 경우, 구간이 단 하나인 경우입니다. 이 때는 노드에 해당 구간의 값을 적어 줍니다.

그렇지 않은 경우, 구간의 중간점 mid를 구하고, tree[node]에는 재귀적으로 함수를 두 번 호출한 값을 기록합니다.

구간은 [start, mid]와 [mid+1, end]로 나뉘고, 노드는 각각 node*2, node*2 + 1가 됩니다.

이 내용은 종이에 직접 트리를 그려 가며 코드를 따라가 보면 이해가 쉽습니다.

세그먼트 트리의 생성에는 O (N log N)의 시간 및 공간 복잡도가 요구됩니다.

 

그러면 위 과정을 통해 만든 세그먼트 트리를 어떻게 활용할 수 있을까요?

예를 들어 인덱스 3~7의 부분합을 구한다고 가정합시다.

우리가 가진 세그먼트 트리를 다시 보면, 맨 처음 구간은 0~3, 4~7로 나뉩니다.

맨 처음에는 구간 0~7을 보면, 우리가 구하려는 3~7과 일치하지 않으며, 우리가 구하려는 구간은 0~7 내에 존재합니다.

따라서 구간을 반으로 나누어 0~3, 4~7을 찾아봅니다.

 

우리는 인덱스 3의 값과 4~7의 값을 구한 후 더하면 됩니다.

4~7은 바로 밑에 존재하므로 해당 tree의 값을 return해 주면 됩니다.

3의 경우 구간 0~3 내에 존재하므로, 구간 0~1, 2~3을 나누어 탐색합니다.

구간 0~1 내에는 존재하지 않으므로 0을 return해 주고, 2~3 내에는 존재하므로 구간을 2와 3으로 나눕니다.

구간 2 내에는 구간 3이 존재하지 않으므로 0을 return해 주고, 구간 3은 구간 3을 완벽히 포함하므로 해당 tree의 값을 return해 줍니다.

 

코드로 나타내면 아래와 같습니다.

 

int get_sum(int start, int end, int left, int right, int node, vector<int> &tree) {
    if (left > end || right < start) return 0;
    if (left <= start && end <= right) return tree[node];
    int mid = (start + end) / 2;
    return get_sum(start, mid, left, right, node*2, tree) + get_sum(mid+1, end, left, right, node*2 + 1, tree);
    
}

 

따라서 구간 합을 O(log n)의 시간복잡도 이내에 구할 수 있습니다.

 

마지막으로, 특정한 인덱스 i의 값을 수정할 때를 생각해 봅시다.

인덱스 5의 값을 -3만큼 수정한다면, 보통 배열에서는 해당 인덱스에 접근하여 -3만큼 값을 수정하면 끝납니다. O(1)의 시간복잡도로 마칠 수 있죠.

하지만 우리는 구간합을 미리 구해둔 세그먼트 트리를 이용하고 있습니다. 따라서 값을 수정하는 경우, 해당 인덱스를 포함하는 모든 트리 노드를 수정해야 합니다.

 

루트 노드는 구간 0~7의 합을 가지고 있고, 인덱스 5를 포함합니다. 따라서 루트 노드의 값에 -3만큼 변경을 가합니다.

이 변화는 자식 노드들에게 전파 (propagation)되어야 합니다. 따라서 구간을 마찬가지로 반으로 나눕니다.

구간 0~3은 인덱스 5를 포함하지 않습니다. 따라서 수정 없이 마칩니다.

구간 4~7은 인덱스 5를 포함합니다. 따라서 해당 노드의 값에 -3만큼 변경을 가합니다.

... 위와 같은 과정을 반복하면,

인덱스 0~7, 4~7, 4~5, 5의 부분합을 포함하는 노드에 -3만큼 변경을 가해야 합니다.

따라서 한 원소의 값을 수정하는데 O(log n)의 시간복잡도를 가집니다.

 

이 내용 역시 코드로 나타내 보겠습니다.

 

void edit_value (int start, int end, int node, int index, int value, vector<int> &tree) {
    if (index < start || index > end) return;
    tree[node] += value;
    if (start == end) return;
    int mid = (start + end) / 2;
    edit_value(start, mid, node*2, index, value, tree);
    edit_value(mid+1, end, node*2+1, index, value, tree);
}

 

이를 이용해 앞에서 언급한 문제를 풀 수 있습니다 (링크).

 

위 내용을 모두 종합한 코드는 다음과 같습니다.

#include <iostream>
#include <vector>
#include <cmath>
using namespace std;

int closest_square (vector<int> &array) {
    int n = (int)array.size();
    
    int i = 1;
    while(1) {
        if (pow(i, 2) >= n) break;
        i++;
    }
    return pow(i,2) * 2;
}

int segment_tree (vector<int> &array, vector<int> &tree, int start, int end, int node) {
    if (start == end) return tree[node] = array[start];
    int mid = (start + end) / 2;
    
    return tree[node] = segment_tree(array, tree, start, mid, node*2) + segment_tree(array, tree, mid+1, end, node*2 + 1);
}

int get_sum(int start, int end, int left, int right, int node, vector<int> &tree) {
    if (left > end || right < start) return 0;
    if (left <= start && end <= right) return tree[node];
    int mid = (start + end) / 2;
    return get_sum(start, mid, left, right, node*2, tree) + get_sum(mid+1, end, left, right, node*2 + 1, tree);
    
}

void edit_value (int start, int end, int node, int index, int value, vector<int> &tree) {
    if (index < start || index > end) return;
    tree[node] += value;
    if (start == end) return;
    int mid = (start + end) / 2;
    edit_value(start, mid, node*2, index, value, tree);
    edit_value(mid+1, end, node*2+1, index, value, tree);
}

int main() {
    ios::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
    vector<int> array = {1, 2, 9, 8, 4, 5, 3, 7};
    int n = closest_square(array);
    int size = (int)array.size() - 1;
    vector<int> tree(n);
    
    segment_tree(array, tree, 0, size, 1);
    
    cout << "3~7 구간합: " << get_sum(0, size, 3, 7, 1, tree) << endl;
    cout << "인덱스 5의 값을 -3만큼 수정\n";
    edit_value(0, size, 1, 5, -3, tree);
    cout << "3~7 구간합: " << get_sum(0, size, 3, 7, 1, tree) << endl;
}

 

 

내용 참고:

트리 그리기: http://ironcreek.net/syntaxtree/

댓글