다음과 같은 문제를 생각해 봅시다.
배열 arr에는 수많은 int형 정수가 들어 있다. arr의 크기 n은 최대 10만이며, 각 정수는 1000 이하의 크기를 가진다.
배열이므로, 당연히 arr의 모든 원소는 인덱스가 정해져 있다. 인덱스는 0부터 시작한다.
총 m줄에 걸쳐 구간이 주어질 때, 주어진 구간합을 출력하는 코드를 작성하시오. (n <= 10만).
구간이 주어진다는 말은, 예를 들어 구간이 3, 6으로 주어진다면, 인덱스 3부터 6까지 모든 원소의 합을 출력하는 것을 의미합니다.
위 문제를 단순하게 해결해 봅시다.
vector<int> arr = {1, 2, 9, 8, 4, 5, 3, 7};
vector<pair<int, int> > ranges = {{3, 4}, {1, 5}, {2, 6}, {0, 7}};
for (auto& range : ranges) {
long long answer = 0;
for (int i = range.first; i <= range.second; i++) {
answer += arr[i];
}
cout << answer << '\n';
}
배열 arr는 총 8개의 원소를 가지고 있으며, 인덱스 번호를 보기 쉽게 표현해 보면 아래와 같습니다.
인덱스 번호 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
원소 | 1 | 2 | 9 | 8 | 4 | 5 | 3 | 7 |
주어진 구간은 [3, 4], [1, 5], [2, 6], [0, 7] 입니다.
즉 위 문제에서는 n = 8, m = 4가 됩니다.
위 코드대로 문제를 해결한다면 다음과 같은 과정을 수행합니다.
- 주어진 구간에 대해 배열에 각각 접근하고, 합연산을 수행합니다. 최악의 경우 배열 내의 원소를 n번 접근해야 합니다.
- 각 구간에 대해 합연산을 구한 후, 출력합니다. 최악의 경우 m개의 구간합을 구해야 합니다.
따라서 위 코드의 시간 복잡도는 O(nm)이 되며, n과 모두 10만 이하의 자연수이므로, 시간복잡도를 줄일 방법을 고려해야 합니다.
실제로 위와 같은 방법으로 구간 합 구하기 (문제)를 풀면, 시간 초과에 걸립니다.
이런 상황에서 쓸 수 있는 효과적인 알고리즘이 있습니다.
세그먼트 트리라고 하는데, 간단히 말하면 특정 구간의 합을 미리 구해둔 후, 요청이 있을 때 이미 구한 합을 활용하여 답을 구하는 것입니다.
다음과 같은 이진 트리를 생성할 수 있습니다.
숫자 밑의 아랫첨자는 무시해 주세요.
위 트리가 의미하는 것은 무엇일까요?
세그먼트 트리는 한 번에 세 가지 그림을 놓고 보면 이해가 쉽습니다.
맨 처음 본 트리는 정해진 각 구간의 합입니다.
두 번째 트리는 각 구간을 나타내고 있는데, 루트는 모든 구간을 의미하며, 한 노드의 자식들은 부모 노드의 구간을 반으로 나눈 구간을 포함합니다. 부모 노드의 구간이 오직 하나인 경우, 자식 노드는 존재하지 않습니다.
세 번째 트리는 세그먼트 트리의 노드 번호입니다. 0이 아닌 1부터 시작함을 주목해 주세요. 부모 노드 node의 자식 노드는 node*2, node*2 + 1로 쉽게 나타내기 위해 루트 노드는 1번으로 정합니다.
사용하지 않는 0번을 포함하여, 세그먼트 트리의 크기는 "기존 배열의 크기 n보다 큰 최소의 제곱수 * 2" 가 됩니다.
무슨 뜻인지 이해가 어렵다면, 이를 구하는 코드를 참고해 주세요.
int closest_square (int n) {
//n == array's size
int i = 1;
while(1) {
if (pow(i, 2) >= n) break;
i++;
}
return pow(i,2) * 2;
}
이번 경우에는 n = 8이므로, n보다 큰 가장 작은 제곱수는 9가 될 것이고, 따라서 세그먼트 트리의 크기는 18이 됩니다.
실제로는 데이터의 개수 n * 4 만큼 미리 세그먼트 트리 영역을 할당하면 됩니다.
그렇다면 세그먼트 트리를 어떻게 만들 수 있을까요?
우리가 주목할 점은, "자식 노드 두 개는 부모 노드의 구간을 반으로 나눈 구간을 가진다"는 점입니다.
따라서 구간이 단 하나만 될 때까지 구간을 반으로 나누어 가며 재귀적으로 접근할 수 있습니다.
vector<int> array는 문제에서 주어진 배열, vector<int> tree는 만들어 둔 빈 세그먼트 트리 배열이라고 합시다.
int segment_tree (vector<int> &array, vector<int> &tree, int start, int end, int node) {
if (start == end) return tree[node] = array[start];
int mid = (start + end) / 2;
return tree[node] = segment_tree(array, tree, start, mid, node*2) + segment_tree(array, tree, mid+1, end, node*2 + 1);
}
start == end인 경우, 구간이 단 하나인 경우입니다. 이 때는 노드에 해당 구간의 값을 적어 줍니다.
그렇지 않은 경우, 구간의 중간점 mid를 구하고, tree[node]에는 재귀적으로 함수를 두 번 호출한 값을 기록합니다.
구간은 [start, mid]와 [mid+1, end]로 나뉘고, 노드는 각각 node*2, node*2 + 1가 됩니다.
이 내용은 종이에 직접 트리를 그려 가며 코드를 따라가 보면 이해가 쉽습니다.
세그먼트 트리의 생성에는 O (N log N)의 시간 및 공간 복잡도가 요구됩니다.
그러면 위 과정을 통해 만든 세그먼트 트리를 어떻게 활용할 수 있을까요?
예를 들어 인덱스 3~7의 부분합을 구한다고 가정합시다.
우리가 가진 세그먼트 트리를 다시 보면, 맨 처음 구간은 0~3, 4~7로 나뉩니다.
맨 처음에는 구간 0~7을 보면, 우리가 구하려는 3~7과 일치하지 않으며, 우리가 구하려는 구간은 0~7 내에 존재합니다.
따라서 구간을 반으로 나누어 0~3, 4~7을 찾아봅니다.
우리는 인덱스 3의 값과 4~7의 값을 구한 후 더하면 됩니다.
4~7은 바로 밑에 존재하므로 해당 tree의 값을 return해 주면 됩니다.
3의 경우 구간 0~3 내에 존재하므로, 구간 0~1, 2~3을 나누어 탐색합니다.
구간 0~1 내에는 존재하지 않으므로 0을 return해 주고, 2~3 내에는 존재하므로 구간을 2와 3으로 나눕니다.
구간 2 내에는 구간 3이 존재하지 않으므로 0을 return해 주고, 구간 3은 구간 3을 완벽히 포함하므로 해당 tree의 값을 return해 줍니다.
코드로 나타내면 아래와 같습니다.
int get_sum(int start, int end, int left, int right, int node, vector<int> &tree) {
if (left > end || right < start) return 0;
if (left <= start && end <= right) return tree[node];
int mid = (start + end) / 2;
return get_sum(start, mid, left, right, node*2, tree) + get_sum(mid+1, end, left, right, node*2 + 1, tree);
}
따라서 구간 합을 O(log n)의 시간복잡도 이내에 구할 수 있습니다.
마지막으로, 특정한 인덱스 i의 값을 수정할 때를 생각해 봅시다.
인덱스 5의 값을 -3만큼 수정한다면, 보통 배열에서는 해당 인덱스에 접근하여 -3만큼 값을 수정하면 끝납니다. O(1)의 시간복잡도로 마칠 수 있죠.
하지만 우리는 구간합을 미리 구해둔 세그먼트 트리를 이용하고 있습니다. 따라서 값을 수정하는 경우, 해당 인덱스를 포함하는 모든 트리 노드를 수정해야 합니다.
루트 노드는 구간 0~7의 합을 가지고 있고, 인덱스 5를 포함합니다. 따라서 루트 노드의 값에 -3만큼 변경을 가합니다.
이 변화는 자식 노드들에게 전파 (propagation)되어야 합니다. 따라서 구간을 마찬가지로 반으로 나눕니다.
구간 0~3은 인덱스 5를 포함하지 않습니다. 따라서 수정 없이 마칩니다.
구간 4~7은 인덱스 5를 포함합니다. 따라서 해당 노드의 값에 -3만큼 변경을 가합니다.
... 위와 같은 과정을 반복하면,
인덱스 0~7, 4~7, 4~5, 5의 부분합을 포함하는 노드에 -3만큼 변경을 가해야 합니다.
따라서 한 원소의 값을 수정하는데 O(log n)의 시간복잡도를 가집니다.
이 내용 역시 코드로 나타내 보겠습니다.
void edit_value (int start, int end, int node, int index, int value, vector<int> &tree) {
if (index < start || index > end) return;
tree[node] += value;
if (start == end) return;
int mid = (start + end) / 2;
edit_value(start, mid, node*2, index, value, tree);
edit_value(mid+1, end, node*2+1, index, value, tree);
}
이를 이용해 앞에서 언급한 문제를 풀 수 있습니다 (링크).
위 내용을 모두 종합한 코드는 다음과 같습니다.
#include <iostream>
#include <vector>
#include <cmath>
using namespace std;
int closest_square (vector<int> &array) {
int n = (int)array.size();
int i = 1;
while(1) {
if (pow(i, 2) >= n) break;
i++;
}
return pow(i,2) * 2;
}
int segment_tree (vector<int> &array, vector<int> &tree, int start, int end, int node) {
if (start == end) return tree[node] = array[start];
int mid = (start + end) / 2;
return tree[node] = segment_tree(array, tree, start, mid, node*2) + segment_tree(array, tree, mid+1, end, node*2 + 1);
}
int get_sum(int start, int end, int left, int right, int node, vector<int> &tree) {
if (left > end || right < start) return 0;
if (left <= start && end <= right) return tree[node];
int mid = (start + end) / 2;
return get_sum(start, mid, left, right, node*2, tree) + get_sum(mid+1, end, left, right, node*2 + 1, tree);
}
void edit_value (int start, int end, int node, int index, int value, vector<int> &tree) {
if (index < start || index > end) return;
tree[node] += value;
if (start == end) return;
int mid = (start + end) / 2;
edit_value(start, mid, node*2, index, value, tree);
edit_value(mid+1, end, node*2+1, index, value, tree);
}
int main() {
ios::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
vector<int> array = {1, 2, 9, 8, 4, 5, 3, 7};
int n = closest_square(array);
int size = (int)array.size() - 1;
vector<int> tree(n);
segment_tree(array, tree, 0, size, 1);
cout << "3~7 구간합: " << get_sum(0, size, 3, 7, 1, tree) << endl;
cout << "인덱스 5의 값을 -3만큼 수정\n";
edit_value(0, size, 1, 5, -3, tree);
cout << "3~7 구간합: " << get_sum(0, size, 3, 7, 1, tree) << endl;
}
내용 참고:
트리 그리기: http://ironcreek.net/syntaxtree/
'알고리즘, 문제해결 > 알고리즘, 자료구조' 카테고리의 다른 글
[Java] Java에서 사용하는 여러 자료구조 정리 (0) | 2021.10.18 |
---|---|
느리게 갱신되는 세그먼트 트리 (0) | 2021.09.26 |
C++의 auto에 대해 (0) | 2021.08.31 |
[SQL] String, Date (0) | 2021.03.11 |
[SQL] JOIN (0) | 2021.03.11 |
댓글