포스트

[Baekjoon] 11658번: 구간 합 구하기 3 (Platinum) - C++ 풀이

[Baekjoon] 11658번: 구간 합 구하기 3 (Platinum) - C++ 풀이

문제

문제 링크

N×N개의 수가 N×N 크기의 표에 채워져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 표의 i행 j열은 (i, j)로 나타낸다. (x1, y1)부터 (x2, y2)까지 합이란 x1 ≤ x ≤ x2, y1 ≤ y ≤ y2를 만족하는 모든 (x, y)에 있는 수의 합이다.

예를 들어, N = 4이고, 표가 아래와 같이 채워져 있는 경우를 살펴보자.

1 2 3 4
2 3 4 5
3 4 5 6
4 5 6 7

여기서 (2, 2)부터 (3, 4)까지 합을 구하면 3+4+5+4+5+6 = 27이 된다. (2, 3)을 7로 바꾸고 (2, 2)부터 (3, 4)까지 합을 구하면 3+7+5+4+5+6=30 이 된다.

표에 채워져 있는 수와 변경하는 연산과 합을 구하는 연산이 주어졌을 때, 이를 처리하는 프로그램을 작성하시오.

입력

첫째 줄에 표의 크기 N과 수행해야 하는 연산의 수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는 네 개의 정수 w, x, y, c 또는 다섯 개의 정수 w, x1, y1, x2, y2가 주어진다. w = 0인 경우는 (x, y)를 c (1 ≤ c ≤ 1,000)로 바꾸는 연산이고, w = 1인 경우는 (x1, y1)부터 (x2, y2)의 합을 구해 출력하는 연산이다. (1 ≤ x1 ≤ x2 ≤ N, 1 ≤ y1 ≤ y2 ≤ N) 표에 채워져 있는 수는 1,000보다 작거나 같은 자연수이다.

출력

w = 1인 입력마다 구한 합을 순서대로 한 줄에 하나씩 출력한다.

문제 조건

  • 목표: N×N 크기의 2차원 배열에서 특정 지점의 값을 수정하고, 주어진 직사각형 범위(x1, y1 ~ x2, y2) 내 원소들의 합을 효율적으로 구한다.

  • 입력 상태: 표의 크기 N(≤1024), 연산 횟수 M(≤100,000), 초기 배열 값 및 갱신/쿼리 명령이 주어진다.

  • 핵심 조건: 값의 변경(Update)이 빈번하게 일어나므로 단순 누적 합(Prefix Sum) 방식(Update 시 O(N²)) 대신, 수정과 구간 합 계산을 모두 로그 시간에 처리할 수 있는 2차원 세그먼트 트리를 사용해야 한다.


풀이

핵심 알고리즘

자료 구조

  • 시간 복잡도: O(N² + M log² N) — 초기 트리 구축 O(N²), 각 쿼리 및 업데이트 O(log N · log N). 최악의 경우 약 1,000만 번의 연산 수행.

핵심 아이디어

1차원 세그먼트 트리의 각 노드가 또 다른 세그먼트 트리를 관리하는 ‘Tree of Trees’ 구조(2차원 세그먼트 트리)를 구현한다. Y축(행)을 분할하는 세그먼트 트리를 먼저 구성하고, 각 행 노드 내부에 해당 행 범위에 속하는 X축(열)의 합을 저장하는 세그먼트 트리를 중첩하여 관리한다.

① 2차원 세그먼트 트리 초기화

Y축(행) 범위를 반으로 나누며 하향식으로 내려가고, 리프 노드(단일 행) 혹은 병합 노드(행 범위 합)에 도달할 때마다 해당 구간의 X축(열) 정보를 담은 세그먼트 트리를 생성하여 2차원 구조를 완성한다.

1
2
3
4
5
6
7
8
void init_y(int node_y, int start_y, int end_y) {
	if (start_y < end_y) {
		int mid_y = start_y + ((end_y - start_y) >> 1);
		init_y(node_y << 1, start_y, mid_y);       // 위쪽 절반
		init_y(node_y << 1 | 1, mid_y + 1, end_y); // 아래쪽 절반
	}
	init_x(node_y, start_y, end_y, 1, 1, n);
}

② 하향식 점 업데이트 (Update)

특정 좌표 (y, x)의 값을 수정할 때, Y축 트리에서 해당 y좌표를 포함하는 모든 노드를 타고 내려가며 각각의 X축 트리에서 x좌표에 해당하는 값을 갱신한다. 이때 Y축 리프 노드인 경우(flag=true)에는 새 값을 직접 쓰고, 부모 노드인 경우(flag=false)에는 두 자식 노드의 X축 트리 값을 합산하여 갱신한다.

1
2
3
4
5
6
7
8
9
10
11
12
13
void update_x(int y_id, int x_id, int left, int right, int x, int val, bool flag) {
	if (left > x || x > right) return;
	if (left < right) {
		int mid = left + ((right - left) >> 1);
		update_x(y_id, x_id << 1, left, mid, x, val, flag);
		update_x(y_id, x_id << 1 | 1, mid + 1, right, x, val, flag);
		tree[y_id][x_id] = tree[y_id][x_id << 1] + tree[y_id][x_id << 1 | 1];
		return;
	}

	if (flag) tree[y_id][x_id] = val;
	else tree[y_id][x_id] = tree[y_id << 1][x_id] + tree[y_id << 1 | 1][x_id];
}

③ 2차원 구간 합 쿼리 (Query)

구하고자 하는 Y축 범위 [start_y, end_y]에 포함되는 Y축 노드들을 선별하고, 각 노드 내부에 구축된 X축 트리에서 [start_x, end_x] 범위의 부분합을 구해 모두 더한다. Y축과 X축 모두 로그 시간(log N)이 소요되므로 최종적으로 O(log² N)의 효율성을 가진다.

1
2
3
4
5
6
7
8
9
int query(int idx, int top, int bot, int start_y, int start_x, int end_y, int end_x) {
	if (bot < start_y || end_y < top) return 0;
	if (start_y <= top && bot <= end_y) return query_x(idx, 1, 1, n, start_x, end_x);
	int mid = top + ((bot - top) >> 1);
	int yt = query(idx << 1, top, mid, start_y, start_x, end_y, end_x);
	int yb = query(idx << 1 | 1, mid+1, bot, start_y, start_x, end_y, end_x);

	return yt + yb;
}

성능

  • 메모리 : 71792 KB

  • 시간 : 288 ms

코드 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
using namespace std;

const int MX_N = 1'025;
const int MX_M = 100'001;

int map_data[MX_N][MX_N];
int n, m;

int tree[4 * MX_N][4 * MX_N];
// X축(가로) 구간을 초기화하는 함수
void init_x(int node_y, int start_y, int end_y, int node_x, int start_x, int end_x) {
	if (start_x >= end_x) {
		if (start_y >= end_y) {
			tree[node_y][node_x] = map_data[start_y][start_x];
		}
		else {
			tree[node_y][node_x] = tree[node_y << 1][node_x] + tree[node_y << 1 | 1][node_x];
		}
		return;
	}

	int mid_x = start_x + ((end_x - start_x) >> 1);

	init_x(node_y, start_y, end_y, node_x << 1, start_x, mid_x);
	init_x(node_y, start_y, end_y, node_x << 1 | 1, mid_x + 1, end_x);

	tree[node_y][node_x] = tree[node_y][node_x << 1] + tree[node_y][node_x << 1 | 1];
}

// Y축(세로) 구간을 초기화하는 함수
void init_y(int node_y, int start_y, int end_y) {
	if (start_y < end_y) {
		int mid_y = start_y + ((end_y - start_y) >> 1);
		init_y(node_y << 1, start_y, mid_y);       // 위쪽 절반
		init_y(node_y << 1 | 1, mid_y + 1, end_y); // 아래쪽 절반
	}
	init_x(node_y, start_y, end_y, 1, 1, n);
}

int query_x(int y_id, int x_id, int left, int right, int start_x, int end_x) {
	if (right < start_x || end_x < left) return 0;
	if (start_x <= left && right <= end_x) return tree[y_id][x_id];
	int mid = left + ((right - left) >> 1);
	int xl = query_x(y_id, x_id << 1, left, mid, start_x, end_x);
	int xr = query_x(y_id, x_id << 1 | 1, mid + 1, right, start_x, end_x);

	return xl + xr;
}

int query(int idx, int top, int bot, int start_y, int start_x, int end_y, int end_x) {
	if (bot < start_y || end_y < top) return 0;
	if (start_y <= top && bot <= end_y) return query_x(idx, 1, 1, n, start_x, end_x);
	int mid = top + ((bot - top) >> 1);
	int yt = query(idx << 1, top, mid, start_y, start_x, end_y, end_x);
	int yb = query(idx << 1 | 1, mid+1, bot, start_y, start_x, end_y, end_x);

	return yt + yb;
}

void update_x(int y_id, int x_id, int left, int right, int x, int val, bool flag) {
	if (left > x || x > right) return;
	if (left < right) {
		int mid = left + ((right - left) >> 1);
		update_x(y_id, x_id << 1, left, mid, x, val, flag);
		update_x(y_id, x_id << 1 | 1, mid + 1, right, x, val, flag);
		tree[y_id][x_id] = tree[y_id][x_id << 1] + tree[y_id][x_id << 1 | 1];
		return;
	}

	if (flag) tree[y_id][x_id] = val;
	else tree[y_id][x_id] = tree[y_id << 1][x_id] + tree[y_id << 1 | 1][x_id];
}

void update(int idx, int top, int bot, int y, int x, int val) {
	if (top > y || y > bot) return;
	if (top < bot) {
		int mid = top + ((bot - top) >> 1);
		update(idx << 1, top, mid, y, x, val);
		update(idx << 1 | 1, mid + 1, bot, y, x, val);
	}
	update_x(idx, 1, 1, n, x, val, (top==bot));
}


int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);
	//freopen("test.txt", "r", stdin);


	cin >> n >> m;
	for (int y = 1; y <= n; ++y)
	{
		for (int x = 1; x <= n; ++x)
		{
			cin >> map_data[y][x];
		}
	}
	init_y(1, 1, n);
	int w, x1, y1, x2, y2, c;
	while (m--)
	{
		cin >> w;
		if (!w) {
			cin >> x1 >> y1 >> c;
			update(1, 1, n, x1, y1, c);
		}
		else {
			cin >> x1 >> y1 >> x2 >> y2;
			cout << query(1, 1, n, x1, y1, x2, y2) << '\n';
		}
	}
	return 0;
}

이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.