세그먼트 트리 (Segment Tree)

개념

길이 N의 배열이 하나 있다고 가정하자.

1
1
arr = [5, 4, 3, 6, 5, 3, 1, 2]

이 길이 N 배열의 임의의 한 구간에 대한 합을 묻는 질문이 M회 들어온다면??

1
2
3
4
5
6
7
8
9
Q : 2~6번째 원소의 합이 얼마야?
A : 21이야
Q : 1~2번째 원소의 합이 얼마야?
A : 9야
Q : 1~8번째 원소의 합이 얼마야?
A : 29야
Q : 5~6번째 원소의 합이 얼마야?
A : 8이야
...

매번 배열을 참조해서 답변을 하기에는 무언가 비효율적이라는 것이 직관적으로 느껴질 것이다. 한 번의 질문마다 최대 N회 참조를 해야 한다.

-> 시간복잡도 O(NM)

prefix-sum array

여러 가지 개선방안이 있겠지만, 이런 상황에서 상당히 유용한 기법 중 하나는 배열의 누적 합을 미리 계산해두는 것이다

1
2
3
4
5
arr = [5, 4, 3, 6, 5, 3, 1, 2]
sum_arr = [arr[0]]
for i in arr[1:]:
    sum_arr.append(sum_arr[-1]+i)
sum_arr
1
>> [5, 9, 12, 18, 23, 26, 27, 29]

이렇게 n 번째 원소까지의 합이 저장된 배열이 있다면, 구간 A~B의 합을 구하고 싶을 때 ‘B까지의 합’ - ‘A-1까지의 합’을 구함으로써 O(1) 시간복잡도에 구할 수 있을 것이다.

1
2
3
4
5
6
def get_sum(A, B):
    if A == 0:
        return sum_arr[B]
    else:
        sum_arr[B] - sum_arr[A-1]
    return 

그럼 질문의 형태가 다음과 같이 바뀐다면??

1
2
3
4
5
6
7
8
9
10
11
Q : 2~6번째 원소의 합이 얼마야?
A : 21이야
U : 3번째 원소 값을 8로 바꿔줘
Q : 1~2번째 원소의 합이 얼마야?
A : 9야
Q : 1~8번째 원소의 합이 얼마야?
A : 34야
U : 6번째 원소 값을 50으로 바꿔줘
Q : 5~6번째 원소의 합이 얼마야?
A : 55야
...

prefix-sum array 를 이용하면 Q 에 대한 대답은 O(1)만에 가능하겠지만, U 에 대한 배열 업데이트는 O(N)이 되므로, 결국 최종적인 시간복잡도는 O(NM)이 될 것이다

segment tree

prefix-sum array 가 1~N까지의 원소의 합을 저장했다면, segment array는 구간을 두 갈래로 나눠가면서 해당 구간의 합을 기록하는 식이다.

1
1
arr = [5, 4, 3, 6, 5, 3, 1, 2]

위의 배열을 예시로 들면, 만들어지는 segment tree는 다음과 같게 된다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
seg_tree = [0,  # 첫 번째 자리는 비워 둠. index 계산을 편하게 하기 위함
            29,  # 1~8 번 째 원소의 합
            18,  # 1~4 번 째 원소의 합
            11,  # 5~8 번 째 원소의 합
            9,  # 1~2 번 째 원소의 합
            9,  # 3~4 번 째 원소의 합
            8,  # 5~6 번 째 원소의 합
            3,  # 7~8 번 째 원소의 합
            5,  # 1 번 째 원소의 합
            4,  # 2 번 째 원소의 합
            3,  # 3 번 째 원소의 합
            6,  # 4 번 째 원소의 합
            5,  # 5 번 째 원소의 합
            3,  # 6 번 째 원소의 합
            1,  # 7 번 째 원소의 합
            2  # 8 번 째 원소의 합
]

위 트리의 값을 유지시켜줌으로써 구간 A~B의 합을 구하는 데에 O(log(N)) 시간복잡도가 필요하게 된다.

트리의 값을 유지시켜주기 위해서는 U 를 처리해야 하는데, 이 역시 O(log(N)) 시간복잡도에 처리가 가능하다.

구현

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

seg_tree = [0 for i in range(4*N)]

def update(X, V, N = 1, S = 1, E = N):  # 값 업데이트
    if S == E:  # 현재 참조중인 원소의 표현범위가 크기 1이라면 (종착지에 도달했다면)
        diff = V - seg_tree[N]
        seg_tree[N] = V
        return diff
    else:
        mid = (S + E) // 2
        if X <= mid:  # 현재 참조중인 표현범위를 절반으로 나누었을 때, 목표 index가 왼쪽에 있다면
            diff = update(X, V, 2 * N, S, mid)
            seg_tree[N] += diff
            return diff
        else:  # 현재 참조중인 표현범위를 절반으로 나누었을 때, 목표 index가 오른쪽에 있다면
            diff = update(X, V, 2 * N + 1, mid + 1, E)
            seg_tree[N] += diff
            return diff


def query(L, R, N = 1, S = 1, E = N):  # 값 조회
    if R < S or E < L:  # 현재 참조중인 원소의 표현범위가 내 목표와 겹치지 않을 때
        return 0
    elif L <= S and E <= R:  # 현재 참조중인 원소의 표현범위가 내 목표 안에 포함될 때
        return seg_tree[N]
    else:  # 현재 참조중인 원소의 표현범위가 내 목표와 일부 겹칠 때
        mid = (S + E) // 2
        
        # 구간을 반으로 쪼개서 재조회
        return query(L, R, N * 2, S, mid) + query(L, R, N * 2 + 1, mid + 1, E)

변형 및 활용

기본 트리 [구간 합 구하기]

key가 튜플로 저장되는 트리 [최솟값과 최댓값]

[구간 곱 구하기]

두 개의 값을 저장하는 트리 [구간 합 구하기]

각주

https://en.wikipedia.org/wiki/Segment_tree