[Baekjoon] 1517번: 버블 소트 (Platinum) - C++ 풀이
문제
N개의 수로 이루어진 수열 A[1], A[2], …, A[N]이 있다. 이 수열에 대해서 버블 소트를 수행할 때, Swap이 총 몇 번 발생하는지 알아내는 프로그램을 작성하시오.
버블 소트는 서로 인접해 있는 두 수를 바꿔가며 정렬하는 방법이다. 예를 들어 수열이 3 2 1 이었다고 하자. 이 경우에는 인접해 있는 3, 2가 바뀌어야 하므로 2 3 1 이 된다. 다음으로는 3, 1이 바뀌어야 하므로 2 1 3 이 된다. 다음에는 2, 1이 바뀌어야 하므로 1 2 3 이 된다. 그러면 더 이상 바꿔야 할 경우가 없으므로 정렬이 완료된다.
입력
첫째 줄에 N(1 ≤ N ≤ 500,000)이 주어진다. 다음 줄에는 N개의 정수로 A[1], A[2], …, A[N]이 주어진다. 각각의 A[i]는 0 ≤ |A[i]| ≤ 1,000,000,000의 범위에 들어있다.
출력
첫째 줄에 Swap 횟수를 출력한다
문제 조건
-
목표: 배열을 정렬할 때 발생하는 버블 소트의 총 스왑(Swap) 횟수를 역전 현상(Inversion Count) 계산을 통해 구한다.
-
입력 상태: N은 최대 500,000이며, 각 원소는 0에서 10억 사이의 절댓값을 가지는 정수 수열이 주어진다.
-
핵심 조건: 버블 소트의 스왑 횟수는 배열 내에서 i < j일 때 A[i] > A[j]를 만족하는 쌍(Inversion)의 개수와 동일하므로, N이 크기 때문에 O(N^2)인 버블 소트 대신 O(N log N)의 효율적인 알고리즘을 사용해야 한다.
풀이
핵심 알고리즘
자료 구조
- 시간 복잡도: O(N log N) — N=500,000일 때 약 500,000 * 19회의 연산으로 제한 시간 내 처리가 가능하다.
핵심 아이디어
병합 정렬(Merge Sort)의 과정에서 두 개의 정렬된 부분 배열을 합칠 때, 오른쪽 부분 배열의 원소가 왼쪽 부분 배열의 원소보다 작아서 먼저 선택되는 경우를 추적한다. 이때 왼쪽 부분 배열에 남아 있는 모든 원소는 현재 선택된 오른쪽 원소보다 크므로, 그 개수만큼 스왑(역전)이 발생한 것으로 간주하여 합산한다.
① 분할 정복을 통한 병합 정렬 구조
배열을 더 이상 나눌 수 없을 때까지 절반으로 나누어 정렬된 상태로 병합해 나간다. 이 과정에서 각 단계의 병합(merge) 함수를 호출하여 정렬된 상태를 유지함과 동시에 역전 현상을 카운트할 준비를 한다.
1
2
3
4
5
6
7
8
9
10
11
void Msort(int left, int right)
{
int mid = (left + right) / 2;
if (left < right)
{
Msort(left, mid);
Msort(mid + 1, right);
merge(left, mid, right);
}
}
② 병합 과정에서의 스왑(역전) 횟수 계산
두 부분 배열을 비교할 때 오른쪽 배열의 원소(arr[j])가 왼쪽 배열의 원소(arr[i])보다 작다면, arr[j]는 현재 왼쪽 배열에 남아있는 모든 원소(arr[i]부터 arr[mid])보다 작다는 것을 의미한다. 따라서 ‘mid - i + 1’만큼의 역전이 발생한 것이며, 이를 누적하여 전체 스왑 횟수를 구한다.
1
2
3
4
5
6
7
8
9
10
11
12
while (i <= mid && j <= right)
{
if (arr[i] <= arr[j])
{
set[k++] = arr[i++];
}
else
{
cnt += (long long)(mid - i + 1);
set[k++] = arr[j++];
}
}
③ 결과 저장 및 자료형 선택
N이 500,000일 때 최대 스왑 횟수는 N(N-1)/2로 약 1,250억에 달하므로, 32비트 정수형(int)의 범위를 훨씬 초과한다. 따라서 이를 안전하게 저장하기 위해 64비트 정수형인 long long을 사용하여 오버플로우를 방지한다.
1
2
3
4
5
6
7
long long arr[500005];
long long set[500005];
long long cnt = 0;
// ... 중략 ...
cout << cnt;
성능
-
메모리 : 9832 KB
-
시간 : 108 ms
코드 (C++)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <iostream>
#include <string>
using namespace std;
long long arr[500005];
long long set[500005];
long long cnt = 0;
void merge(int left, int mid, int right)
{
int i, j, k, l;
i = left;
j = mid + 1;
k = left;
while (i <= mid && j <= right)
{
if (arr[i] <= arr[j])
{
set[k++] = arr[i++];
}
else
{
cnt += (long long)(mid - i + 1);
set[k++] = arr[j++];
}
}
if (i > mid)
{
for (l = j; l <= right; l++)
{
set[k++] = arr[l];
}
}
else
{
for (l = i; l <= mid; l++)
{
set[k++] = arr[l];
}
}
for (l = left; l <= right; l++)
{
arr[l] = set[l];
}
}
void Msort(int left, int right)
{
int mid = (left + right) / 2;
if (left < right)
{
Msort(left, mid);
Msort(mid + 1, right);
merge(left, mid, right);
}
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int N;
cin >> N;
for (int i = 0; i < N; i++)
cin >> arr[i];
Msort(0, N - 1);
// cout<<"( ";
// for (int i = 0; i < N; i++)
// cout<<arr[i]<<" ";
// cout<<")\n";
cout << cnt;
}