세그먼트 트리(Segment Tree)
Updated:
세그먼트 트리란?
여러 개의 데이터가 연속적으로 존재할 때 특정한 범위의 데이터 합을 가장 빠르고 간단하게 구할 수 있는 자료구조이다.
예를들면 길이가 N인 배열 A에서 A[i]부터 A[j]까지의 부분합을 구하고 A[k]=V로 바꾸어라.
- 부분합을 구하는 시간 복잡도 : $O(N)$
- A[k]=V로 바꾸는 시간 복잡도 : $O(1)$ → 쿼리가 M개인 경우 총 시간 복잡도는 $O(MN)$을 가진다.
그러나, 세그먼트 트리를 이용하면 두 쿼리 모두 $O(logN)$의 시간 복잡도를 가지게 된다.
세그먼트 트리 구조
길이 10의 순열을 세그먼트 트리로 구성하면 다음 그림과 같다.
트리를 만드는 방법은 다음과 같다.
- 루트는 전체 순열의 합이 들어간다.
- 자식 노드는 부모의 데이터를 절반씩 나누어 구간 합을 저장한다.
이 과정을 반복하면 구간 합 트리의 전체 노드를 구할 수 있다.
이때, 루트의 번호는 0이 아니라 1을 의미한다.
왜냐하면 다음 왼쪽 자식이 2, 오른쪽 자식이 3을 가리키게 되면서 부모노드 번호에서 2를 곱하면 왼쪽 자식노드를 의미하기 때문에 효과적이다. 아래 그림을 보면 이해하기 쉽다.
또한, 구현 방식은 재귀적으로 구하는 것이 더 간단하다.
// start: 시작 인덱스, end: 끝 인덱스
int init(int start, int end, int node) {
if(start == end) return tree[node] = a[start];
int mid = (start + end) / 2;
// 재귀적으로 두 부분으로 나눈 뒤에 그 합을 자기 자신으로 한다.
return tree[node] = init(start, mid, node * 2) + init(mid + 1, end, node * 2 + 1);
}
구간 합 구하기
위의 그림에서 a[4] ~ a[8]의 구간합을 구하고 싶다면 녹색으로 칠해진 노드의 값만 더해주면 된다.
// start : 시작 인덱스, end : 끝 인덱스
// left, right : 구간 합을 구하고자 하는 범위
int sum(int start, int end, int node, int left, int right) {
// 범위 밖에 있는 경우
if(left > end || right < start) return 0;
// 범위 안에 있는 경우
if(left <= start && end <= right) return tree[node];
// 그렇지 않다면 두 부분으로 나누어 합을 구하기
int mid = (start + end) / 2;
return sum(start, mid, node * 2, left, rigth) + sum(mid + 1, end, node * 2 + 1);
}
특정 원소의 값 바꾸기
특정 원소의 값을 바꾸고 싶다면 해당 원소를 포함하고 있는 모든 구간 합 노드를 갱신해야 한다.
// start : 시작 인덱스, end : 끝 인덱스
// index : 구간 합을 수정하고자 하는 노드
// dif : 수정할 값과 원래의 값의 차이 (val - a[index])
void update(int start, int end, int node, int index, int dif) {
// 범위 밖에 있는 경우
if(index < start || index > end) return;
// 범위 안에 있으면 내려가며 다른 원소도 갱신
tree[node] += dif;
if(start == end) return;
int mid = (start + end) / 2;
update(start, mid, node * 2, index, dif);
update(mid + 1, end, node * 2 + 1, index, dif);
}
전체 코드
#include <iostream>
#include <vector>
#define NUM 13
using namespace std;
int a[]={1,9,3,8,4,5,5,9,10,3,4,5};
int tree[4*NUM];
/*
4를 곱하면 모든 범위를 커버할 수 있다.
갯수에 대해서 2의 제곱 형태의 길이를 가지기 되기 때문
*/
int init(int start, int end, int node) {
if(start == end)
return tree[node] = a[start];
int mid = (start + end) / 2;
return tree[node] = init(start, mid, node * 2) + init(mid + 1, end, node * 2 + 1);
}
int sum(int start, int end, int node, int left, int right) {
if(left > end || right < start) return 0;
if(left <= start && end <= right) return tree[node];
int mid = (start + end) / 2;
return sum(start, mid, node * 2, left, right) + sum(mid + 1, end, node * 2 + 1, left, right);
}
void update(int start, int end, int node, int index, int dif) {
if(index < start || index > end) return;
tree[node] += dif;
if(start == end) return;
int mid = (start + end) / 2;
update(start, mid, node * 2, index, dif);
update(mid + 1, end, node * 2 + 1, index, dif);
}
int main(void) {
init(0, NUM-1, 1);
cout<<"0부터 12까지의 구간 합: "<<sum(0, NUM-1,1,0,12)<<endl;
cout<<"3부터 8까지의 구간 합: "<<sum(0,NUM-1,1,3,8)<<endl;
cout<<"인덱스 5의 원소를 0으로 수정"<<endl;
update(0,NUM-1,1,5,-5); // val - a[index]
cout<<"3부터 8까지의 구간 합: "<<sum(0,NUM-1,1,3,8)<<endl;
return 0;
}
/*
결과:
0부터 12까지의 구간 합 : 66
3부터 8까지의 구간 합 : 41
인덱스 5의 원소를 0으로 수정
3부터 8까지의 구간 합 : 36
*/
Comments