07578번 공장
Problem link: https://www.acmicpc.net/problem/7578
Keywords: Inversion Counting, Merge Sort
문제를 차근차근 생각해보면 이 문제는 별 다를 것이 없는 inversion counting 문제이다.
널리 알려져있듯이 inversion counting 문제는 merge sort를 쓰면 무리 없이 풀리므로, inversion counting을 푸는 방법을 여기에 적는 것은 전자 낭비일 것 같다.
그래서, 여기서는 이 문제가 왜 inversion counting 문제로 reduction되는지 정도를 적어두려한다.
주어진 입력이 아래와 같았다고 해보자.
A: 132 392 311 351 231
B: 392 351 132 311 231
가장 무식하게 풀면 아래와 같은 모든 경우를 헤아려서 그 수를 더해주면 된다.
132번 기계 입장에서는 A열에서는132번 보다 뒤에 있지만, B열에서는132번 보다 앞에 있는 기계 수를 세어준다.392번 기계 입장에서는 A열에서는392번 보다 뒤에 있지만, B열에서는392번 보다 앞에 있는 기계 수를 세어준다.- …
동일한 것을 조금만 달리 표현해보자.
B열의 기계들에 대해서 각 기계가 A열에서 어느 위치에 있었는지를 POS_AT_A와 같이 표현해보자.
POS_AT_A에서 inversion을 세어주면 이게 곧 위에서 구했던 교차 케이블 쌍의 수와 같음을 알 수 있다.
POS_AT_A에서 두 원소 쌍POS_AT_A[i],POS_AT_A[j]의 대소관계는 곧 A열에서 두 기계i,j의 순서를 나타낸다(작은 것이 더 앞에 온다).- 따라서,
POS_AT_A[i] < POS_AT_A[j] && i > j가 의미하는 바는 기계i가 기계j보다 A열에서는 앞에 위치 해있었는데, B열에서는 기계i보다 뒤에 있다는 것이다.
A : 132 392 311 351 231
B : 392 351 132 311 231
POS : 0 1 2 3 4
POS_AT_A: 1 3 0 2 4
아래는 merge sort를 사용한 솔루션이다.
#include <cstdint>
#include <iostream>
#include <map>
using namespace std;
constexpr int MAX_N = 500'000;
constexpr int MAX_MACHINE_NO = 1'000'000;
// Inputs.
int N;
int ARR[MAX_N];
int POS[MAX_MACHINE_NO + 1];
// Solution.
int ARR_COPY[MAX_N];
int64_t CountInversion(int s, int e) {
if (s >= e) {
return 0;
}
int m = (s + e) / 2;
int64_t inv = CountInversion(s, m) + CountInversion(m + 1, e);
int offset = s;
int left = s;
int right = m + 1;
while (left <= m && right <= e) {
if (ARR[left] < ARR[right]) {
ARR_COPY[offset++] = ARR[left++];
} else {
ARR_COPY[offset++] = ARR[right++];
inv += m - left + 1;
}
}
while (left <= m) {
ARR_COPY[offset++] = ARR[left++];
}
while (right <= e) {
ARR_COPY[offset++] = ARR[right++];
}
for (int i = s; i <= e; ++i) {
ARR[i] = ARR_COPY[i];
}
return inv;
}
int64_t Solve() { return CountInversion(0, N - 1); }
int main() {
// For faster IO.
ios_base::sync_with_stdio(false);
cout.tie(nullptr);
cin.tie(nullptr);
// Read inputs.
cin >> N;
for (int i = 0; i < N; ++i) {
int machine_no;
cin >> machine_no;
POS[machine_no] = i;
}
for (int i = 0; i < N; ++i) {
int machine_no;
cin >> machine_no;
ARR[i] = POS[machine_no];
}
// Solve.
cout << Solve() << '\n';
return 0;
}