포스트

[Baekjoon] 11658번: 구간 합 구하기 3 (Platinum) - C++ 풀이

[Baekjoon] 11658번: 구간 합 구하기 3 (Platinum) - C++ 풀이

문제 링크

문제 설명

N×N개의 수가 N×N 크기의 표에 채워져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 표의 i행 j열은 (i, j)로 나타낸다. (x1, y1)부터 (x2, y2)까지 합이란 x1 ≤ x ≤ x2, y1 ≤ y ≤ y2를 만족하는 모든 (x, y)에 있는 수의 합이다.

예를 들어, N = 4이고, 표가 아래와 같이 채워져 있는 경우를 살펴보자.

1 2 3 4
2 3 4 5
3 4 5 6
4 5 6 7

여기서 (2, 2)부터 (3, 4)까지 합을 구하면 3+4+5+4+5+6 = 27이 된다. (2, 3)을 7로 바꾸고 (2, 2)부터 (3, 4)까지 합을 구하면 3+7+5+4+5+6=30 이 된다.

표에 채워져 있는 수와 변경하는 연산과 합을 구하는 연산이 주어졌을 때, 이를 처리하는 프로그램을 작성하시오.

입력

첫째 줄에 표의 크기 N과 수행해야 하는 연산의 수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는 네 개의 정수 w, x, y, c 또는 다섯 개의 정수 w, x1, y1, x2, y2가 주어진다. w = 0인 경우는 (x, y)를 c (1 ≤ c ≤ 1,000)로 바꾸는 연산이고, w = 1인 경우는 (x1, y1)부터 (x2, y2)의 합을 구해 출력하는 연산이다. (1 ≤ x1 ≤ x2 ≤ N, 1 ≤ y1 ≤ y2 ≤ N) 표에 채워져 있는 수는 1,000보다 작거나 같은 자연수이다.

출력

w = 1인 입력마다 구한 합을 순서대로 한 줄에 하나씩 출력한다.

코드 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
using namespace std;

const int MX_N = 1'025;
const int MX_M = 100'001;

int map_data[MX_N][MX_N];
int n, m;

int tree[4 * MX_N][4 * MX_N];
// X축(가로) 구간을 초기화하는 함수
void init_x(int node_y, int start_y, int end_y, int node_x, int start_x, int end_x) {
	if (start_x >= end_x) {
		if (start_y >= end_y) {
			tree[node_y][node_x] = map_data[start_y][start_x];
		}
		else {
			tree[node_y][node_x] = tree[node_y << 1][node_x] + tree[node_y << 1 | 1][node_x];
		}
		return;
	}

	int mid_x = start_x + ((end_x - start_x) >> 1);

	init_x(node_y, start_y, end_y, node_x << 1, start_x, mid_x);
	init_x(node_y, start_y, end_y, node_x << 1 | 1, mid_x + 1, end_x);

	tree[node_y][node_x] = tree[node_y][node_x << 1] + tree[node_y][node_x << 1 | 1];
}

// Y축(세로) 구간을 초기화하는 함수
void init_y(int node_y, int start_y, int end_y) {
	if (start_y < end_y) {
		int mid_y = start_y + ((end_y - start_y) >> 1);
		init_y(node_y << 1, start_y, mid_y);       // 위쪽 절반
		init_y(node_y << 1 | 1, mid_y + 1, end_y); // 아래쪽 절반
	}
	init_x(node_y, start_y, end_y, 1, 1, n);
}

int query_x(int y_id, int x_id, int left, int right, int start_x, int end_x) {
	if (right < start_x || end_x < left) return 0;
	if (start_x <= left && right <= end_x) return tree[y_id][x_id];
	int mid = left + ((right - left) >> 1);
	int xl = query_x(y_id, x_id << 1, left, mid, start_x, end_x);
	int xr = query_x(y_id, x_id << 1 | 1, mid + 1, right, start_x, end_x);

	return xl + xr;
}

int query(int idx, int top, int bot, int start_y, int start_x, int end_y, int end_x) {
	if (bot < start_y || end_y < top) return 0;
	if (start_y <= top && bot <= end_y) return query_x(idx, 1, 1, n, start_x, end_x);
	int mid = top + ((bot - top) >> 1);
	int yt = query(idx << 1, top, mid, start_y, start_x, end_y, end_x);
	int yb = query(idx << 1 | 1, mid+1, bot, start_y, start_x, end_y, end_x);

	return yt + yb;
}

void update_x(int y_id, int x_id, int left, int right, int x, int val, bool flag) {
	if (left > x || x > right) return;
	if (left < right) {
		int mid = left + ((right - left) >> 1);
		update_x(y_id, x_id << 1, left, mid, x, val, flag);
		update_x(y_id, x_id << 1 | 1, mid + 1, right, x, val, flag);
		tree[y_id][x_id] = tree[y_id][x_id << 1] + tree[y_id][x_id << 1 | 1];
		return;
	}

	if (flag) tree[y_id][x_id] = val;
	else tree[y_id][x_id] = tree[y_id << 1][x_id] + tree[y_id << 1 | 1][x_id];
}

void update(int idx, int top, int bot, int y, int x, int val) {
	if (top > y || y > bot) return;
	if (top < bot) {
		int mid = top + ((bot - top) >> 1);
		update(idx << 1, top, mid, y, x, val);
		update(idx << 1 | 1, mid + 1, bot, y, x, val);
	}
	update_x(idx, 1, 1, n, x, val, (top==bot));
}


int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);
	//freopen("test.txt", "r", stdin);


	cin >> n >> m;
	for (int y = 1; y <= n; ++y)
	{
		for (int x = 1; x <= n; ++x)
		{
			cin >> map_data[y][x];
		}
	}
	init_y(1, 1, n);
	int w, x1, y1, x2, y2, c;
	while (m--)
	{
		cin >> w;
		if (!w) {
			cin >> x1 >> y1 >> c;
			update(1, 1, n, x1, y1, c);
		}
		else {
			cin >> x1 >> y1 >> x2 >> y2;
			cout << query(1, 1, n, x1, y1, x2, y2) << '\n';
		}
	}
	return 0;
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.