BOBO's Note

Segment Tree와 Indexed Tree 본문

Algorithm

Segment Tree와 Indexed Tree

bobo_hee 2020. 8. 16. 16:51

세그먼트 트리는 어떠한 배열이 주어졌을 때, 각 구간의 대표값(예. 합, 최대값 등)을 빠르게 구할 수 있는 자료구조이다. 배열의 크기가 N일 때, 임의의 구간에 대한 쿼리를 O(logN)에 수행할 수 있다.

 

인덱스 트리는 팬윅 트리(Fenwick Tree)라고도 하며, 세그먼트 트리처럼 각 구간의 대표값을 빠르게 구할 수 있다. 세그먼트 트리보다 구현하기 더 간단하고 속도가 빠르다는 장점이 있다.

 

이 포스터에서는 구간합을 세그먼트 트리를 이용해 빠르게 구하는 방법에 대해 알아보자.

 

구간합 구하기

배열 arr[0], arr[1], ..., arr[N-1]이 있다고 할 때, 임의의 구간 [l, r]에 대하여 구간합 arr[l] + arr[l+1] + ... + arr[r-1] + arr[r]을 구해라.

 

위 문제를 해결하는 가장 간단한 방법은 매번 배열을 순회하는 방법으로, 쿼리를 수행할 때마다 O(N)의 시간이 소요된다. 따라서 k번의 쿼리를 수행한다면 전체 시간복잡도는 O(kN)이다.

for(int i=l; i<=r; i++){
	sum += arr[i];
}

 

쿼리가 많아질수록 시간복잡도는 증가하게 되는데, 이를 개선하는 방법을 생각해보자. arr[0]~arr[i]까지의 누적합을 i번째 원소에 저장하는 배열 acc_sum을 하나 선언하자. acc_sum을 초기화하는 데에 O(N)이 걸리지만, 이후 구간 [l, r]의 누적합은 acc_sum[r]-arr[l-1]로 O(1)의 시간에 알아낼 수 있다. 따라서 k번의 쿼리를 수행한다면 전체 시간복잡도는 O(k)이다.

int acc_sum[N];

// Initialize acc_sum => O(N)
acc_sum[0] = arr[0];
for(int i=1; i<N; i++){
	acc_sum[i] = acc_sum[i-1] + arr[i];
}

// query: segment sum of [l, r] => O(1)
int sum = acc_sum[r] - acc_sum[l-1];

그러나 이 방법은 주어진 배열의 원소 값에 변경이 생긴다면 acc_sum을 다시 구해야 하므로 O(N)의 시간이 발생하고, 이는 성능저하의 요인이 될 수 있다. 

 

세그먼트 트리를 이용하면 (1) 쿼리 수행 및 (2) 배열의 원소 값 변경 모두 O(logN)의 시간에 수행할 수 있다.

 

Segment Tree

세그먼트 트리는 기본적으로 포화 이진트리 형태이다. 주어진 배열의 원소들은 세그먼트 트리의 leaf 노드에 차례대로 저장된다.

Segment Tree - Leaf Node

그리고 내부 노드들은 자식 노드들의 대표값(이 글에서는 구간합)을 저장한다. 

Segment Tree - Internal Node

 

세그먼트 트리 초기화를 마친 후, 구간합에 대한 쿼리를 수행한다. 루트 노드부터 시작해 원하는 구간에 대하여 쿼리를 내려보내고, 이를 종합하여 쿼리의 결과를 계산한다. 각 노드는 자신이 담당하는 구간이 쿼리의 구간에 완전히 속하면 노드 값을 반환하고, 그렇지 않으면 자식 노드들에게 쿼리를 포워드한다.

 

이때, 쿼리의 구간에 완전히 속한다는 것은 다음과 같은 경우이다.

 

예를 들어 구간 [3, 7]의 구간합을 구하는 쿼리는 다음과 같이 수행된다.

 

1. 루트 노드부터 구간을 검사한다. 노드의 구간이 쿼리 구간에 완전히 속하는 게 아니므로 자식 노드에게 쿼리를 포워드한다.

 

2. depth=1인 자식 노드들도 자신의 구간이 쿼리 구간에 완전히 속하는 게 아니므로 자식 노드에게 쿼리를 포워드한다.

 

3. depth=2인 자식노드들 중, 쿼리 구간과 겹치는 구간이 없으면 0을 바로 리턴한다. 만약 쿼리 구간에 완전히 속하면 자신의 값을 리턴한다. 

 

4. depth=4인 노드들도 마찬가지로 완전히 속하면 자신의 값을 리턴한다.

 

 

Implementation

세그먼트 트리를 Java로 구현해보자. 

 

우선 SegmentTree라는 클래스를 선언하고, 주어지는 배열을 저장하는 nums 배열과 세그먼트 트리를 배열 형태로 저장한 배열 tree를 멤버변수로 갖는다. 그리고 세그먼트 트리의 깊이 depth와 leaf 노드의 개수 leafSize도 멤버변수로 갖는다.

class SegmentTree {
	int[] nums;
	int[] tree;
	int depth;
	int leafSize;

	/* to-be-continue */
}

 

주어진 배열로 세그먼트 트리를 초기화하는 코드는 다음과 같다. leaf 노드라면 배열 원소 값을 저장하고 반환한다. 내부 노드라면 자식 노드에 대하여 makeTree()를 재귀적으로 호출하고, 이를 합하여 노드에 저장 후 반환한다.

makeTree(1, 1, nums.length);

// nodeIndex: 세그먼트 트리에서 해당 노드의 인덱스
// left: 세그먼트 트리에서 해당 노드의 구간 시작 값
// right: 세그먼트 트리에서 해당 노드의 구간 끝 값
int makeTree(int nodeIndex, int left, int right) {
	if (left == right) { // leaf node
		if (left <= nums.length) tree[nodeIndex] = nums[left - 1];
		else tree[nodeIndex] = 0;
		return tree[nodeIndex];
	} else { // internal node
		int mid = left + (right - left) / 2;
		tree[nodeIndex] = makeTree(nodeIndex * 2, left, mid);
		tree[nodeIndex] += makeTree(nodeIndex * 2 + 1, mid + 1, right);
		return tree[nodeIndex];
	}
}

 

임의의 구간에 대한 구간합을 반환하는 함수 query()도 위와 비슷한 방식으로 작동한다.

query(1, 1, nums.length, 3, 7); // 구간 [3, 7]의 합

// nodeIndex: 세그먼트 트리에서 해당 노드의 인덱스
// left: 세그먼트 트리에서 해당 노드의 구간 시작 값
// right: 세그먼트 트리에서 해당 노드의 구간 끝 값
// qLeft: 질의하는 구간의 시작 값
// qRight: 질의하는 구간의 끝 값
int query(int nodeIndex, int left, int right, int qLeft, int qRight) {
	if (qLeft <= left && right <= qRight) return tree[nodeIndex]; // 완전히 속함
	else if (qRight < left || right < qLeft) return 0; // 겹치는 구간이 없음
	else { // 자식 노드에게 쿼리를 포워드
		int mid = left + (right - left) / 2;
		return query(nodeIndex * 2, left, mid, qLeft, qRight)
				+ query(nodeIndex * 2 + 1, mid + 1, right, qLeft, qRight);
	}
}

 

배열의 특정 위치의 값을 변경하고 싶으면, 세그먼트 트리의 루트 노드부터 해당 leaf 노드까지 함께 업데이트 해주어야 한다.

int targetIndex = 2;
int targetValue = 5;
int diff = targetValue - nums[targetIndex];
nums[targetIndex] = targetValue;
st.update(1, 1, nums.length, targetIndex + 1, diff);

// nodeIndex: 세그먼트 트리에서 해당 노드의 인덱스
// left: 세그먼트 트리에서 해당 노드의 구간 시작 값
// right: 세그먼트 트리에서 해당 노드의 구간 끝 값
// index: 배열에서 변경하려는 원소의 인덱스
// diff: 변경하려는 값과 배열에서 원소의 값 사이의 차이
void update(int nodeIndex, int left, int right, int index, int diff) {
	if(nodeIndex >= tree.length) return;
	if(left <= index && index <= right) {
		tree[nodeIndex] += diff;
		this.nums[index] += diff;
		int mid = left + (right - left) / 2;
		update(nodeIndex * 2, left, mid, index, diff);
		update(nodeIndex * 2 + 1, mid + 1, right, index, diff);
	}
}

전체 코드

더보기

 

package DataStructure;

import java.util.Arrays;

public class SegementTreeTest {

	public static void main(String[] args) {
		// TODO Auto-generated method stub
		int nums[] = { 3, 2, 4, 5, 1, 6, 2, 7 };
		SegmentTree st = new SegmentTree(nums);
		System.out.println(st);
		
		System.out.println("sum of range [3, 7] is " + st.query(1, 1, nums.length, 3, 7));
		System.out.println("sum of range [2, 5] is " + st.query(1, 1, nums.length, 2, 5));
	
		int targetIndex = 2;
		int targetValue = 5;
		int diff = targetValue - nums[targetIndex];
		nums[targetIndex] = targetValue;
		st.update(1, 1, nums.length, targetIndex + 1, diff);
		
		System.out.println(st);
		System.out.println("sum of range [3, 7] is " + st.query(1, 1, nums.length, 3, 7));
		
	
	}

}

class SegmentTree {
	int[] nums;
	int[] tree;
	int depth;
	int leafSize;

	public SegmentTree(int[] nums) {
		super();
		this.nums = nums;

		this.depth = 0;
		while (Math.pow(2, this.depth) < nums.length) {
			this.depth++;
		}

		this.leafSize = (int) Math.pow(2, this.depth);
		this.tree = new int[(int) Math.pow(2, this.depth + 1)];
		makeTree(1, 1, nums.length);
	}

	@Override
	public String toString() {
		return "SegmentTree [nums=" + Arrays.toString(nums) + ", tree=" + Arrays.toString(tree) + ", leafSize="
				+ leafSize + ", depth=" + depth + "]";
	}

	int makeTree(int nodeIndex, int left, int right) {
		if (left == right) {
			if (left <= nums.length)
				tree[nodeIndex] = nums[left - 1];
			else
				tree[nodeIndex] = 0;
			return tree[nodeIndex];
		} else {
			int mid = left + (right - left) / 2;
			tree[nodeIndex] = makeTree(nodeIndex * 2, left, mid);
			tree[nodeIndex] += makeTree(nodeIndex * 2 + 1, mid + 1, right);
			return tree[nodeIndex];
		}
	}

	int query(int nodeIndex, int left, int right, int qLeft, int qRight) {
		if (qLeft <= left && right <= qRight)
			return tree[nodeIndex];
		else if (qRight < left || right < qLeft)
			return 0;
		else {
			int mid = left + (right - left) / 2;
			return query(nodeIndex * 2, left, mid, qLeft, qRight)
					+ query(nodeIndex * 2 + 1, mid + 1, right, qLeft, qRight);
		}
	}

	void update(int nodeIndex, int left, int right, int index, int diff) {
		if(nodeIndex >= tree.length) return;
		if(left <= index && index <= right) {
			tree[nodeIndex] += diff;
			this.nums[index] += diff;
			int mid = left + (right - left) / 2;
			update(nodeIndex * 2, left, mid, index, diff);
			update(nodeIndex * 2 + 1, mid + 1, right, index, diff);
		}
	}
}

 

'Algorithm' 카테고리의 다른 글

Dynamic Programming  (0) 2020.07.02
Binary Tree  (0) 2020.06.28
Dijkstra's Algorithm  (0) 2020.05.27
Comments