[Baekjoon] 11658번: 구간 합 구하기 3 (Platinum) - C++ 풀이
문제 설명
N×N개의 수가 N×N 크기의 표에 채워져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 표의 i행 j열은 (i, j)로 나타낸다. (x1, y1)부터 (x2, y2)까지 합이란 x1 ≤ x ≤ x2, y1 ≤ y ≤ y2를 만족하는 모든 (x, y)에 있는 수의 합이다.
예를 들어, N = 4이고, 표가 아래와 같이 채워져 있는 경우를 살펴보자.
| 1 | 2 | 3 | 4 |
| 2 | 3 | 4 | 5 |
| 3 | 4 | 5 | 6 |
| 4 | 5 | 6 | 7 |
여기서 (2, 2)부터 (3, 4)까지 합을 구하면 3+4+5+4+5+6 = 27이 된다. (2, 3)을 7로 바꾸고 (2, 2)부터 (3, 4)까지 합을 구하면 3+7+5+4+5+6=30 이 된다.
표에 채워져 있는 수와 변경하는 연산과 합을 구하는 연산이 주어졌을 때, 이를 처리하는 프로그램을 작성하시오.
입력
첫째 줄에 표의 크기 N과 수행해야 하는 연산의 수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는 네 개의 정수 w, x, y, c 또는 다섯 개의 정수 w, x1, y1, x2, y2가 주어진다. w = 0인 경우는 (x, y)를 c (1 ≤ c ≤ 1,000)로 바꾸는 연산이고, w = 1인 경우는 (x1, y1)부터 (x2, y2)의 합을 구해 출력하는 연산이다. (1 ≤ x1 ≤ x2 ≤ N, 1 ≤ y1 ≤ y2 ≤ N) 표에 채워져 있는 수는 1,000보다 작거나 같은 자연수이다.
출력
w = 1인 입력마다 구한 합을 순서대로 한 줄에 하나씩 출력한다.
코드 (C++)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
using namespace std;
const int MX_N = 1'025;
const int MX_M = 100'001;
int map_data[MX_N][MX_N];
int n, m;
int tree[4 * MX_N][4 * MX_N];
// X축(가로) 구간을 초기화하는 함수
void init_x(int node_y, int start_y, int end_y, int node_x, int start_x, int end_x) {
if (start_x >= end_x) {
if (start_y >= end_y) {
tree[node_y][node_x] = map_data[start_y][start_x];
}
else {
tree[node_y][node_x] = tree[node_y << 1][node_x] + tree[node_y << 1 | 1][node_x];
}
return;
}
int mid_x = start_x + ((end_x - start_x) >> 1);
init_x(node_y, start_y, end_y, node_x << 1, start_x, mid_x);
init_x(node_y, start_y, end_y, node_x << 1 | 1, mid_x + 1, end_x);
tree[node_y][node_x] = tree[node_y][node_x << 1] + tree[node_y][node_x << 1 | 1];
}
// Y축(세로) 구간을 초기화하는 함수
void init_y(int node_y, int start_y, int end_y) {
if (start_y < end_y) {
int mid_y = start_y + ((end_y - start_y) >> 1);
init_y(node_y << 1, start_y, mid_y); // 위쪽 절반
init_y(node_y << 1 | 1, mid_y + 1, end_y); // 아래쪽 절반
}
init_x(node_y, start_y, end_y, 1, 1, n);
}
int query_x(int y_id, int x_id, int left, int right, int start_x, int end_x) {
if (right < start_x || end_x < left) return 0;
if (start_x <= left && right <= end_x) return tree[y_id][x_id];
int mid = left + ((right - left) >> 1);
int xl = query_x(y_id, x_id << 1, left, mid, start_x, end_x);
int xr = query_x(y_id, x_id << 1 | 1, mid + 1, right, start_x, end_x);
return xl + xr;
}
int query(int idx, int top, int bot, int start_y, int start_x, int end_y, int end_x) {
if (bot < start_y || end_y < top) return 0;
if (start_y <= top && bot <= end_y) return query_x(idx, 1, 1, n, start_x, end_x);
int mid = top + ((bot - top) >> 1);
int yt = query(idx << 1, top, mid, start_y, start_x, end_y, end_x);
int yb = query(idx << 1 | 1, mid+1, bot, start_y, start_x, end_y, end_x);
return yt + yb;
}
void update_x(int y_id, int x_id, int left, int right, int x, int val, bool flag) {
if (left > x || x > right) return;
if (left < right) {
int mid = left + ((right - left) >> 1);
update_x(y_id, x_id << 1, left, mid, x, val, flag);
update_x(y_id, x_id << 1 | 1, mid + 1, right, x, val, flag);
tree[y_id][x_id] = tree[y_id][x_id << 1] + tree[y_id][x_id << 1 | 1];
return;
}
if (flag) tree[y_id][x_id] = val;
else tree[y_id][x_id] = tree[y_id << 1][x_id] + tree[y_id << 1 | 1][x_id];
}
void update(int idx, int top, int bot, int y, int x, int val) {
if (top > y || y > bot) return;
if (top < bot) {
int mid = top + ((bot - top) >> 1);
update(idx << 1, top, mid, y, x, val);
update(idx << 1 | 1, mid + 1, bot, y, x, val);
}
update_x(idx, 1, 1, n, x, val, (top==bot));
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
//freopen("test.txt", "r", stdin);
cin >> n >> m;
for (int y = 1; y <= n; ++y)
{
for (int x = 1; x <= n; ++x)
{
cin >> map_data[y][x];
}
}
init_y(1, 1, n);
int w, x1, y1, x2, y2, c;
while (m--)
{
cin >> w;
if (!w) {
cin >> x1 >> y1 >> c;
update(1, 1, n, x1, y1, c);
}
else {
cin >> x1 >> y1 >> x2 >> y2;
cout << query(1, 1, n, x1, y1, x2, y2) << '\n';
}
}
return 0;
}