Skip to main content Link Menu Expand (external link) Document Search Copy Copied

07578번 공장

Problem link: https://www.acmicpc.net/problem/7578

Keywords: Inversion Counting, Merge Sort

문제를 차근차근 생각해보면 이 문제는 별 다를 것이 없는 inversion counting 문제이다.

널리 알려져있듯이 inversion counting 문제는 merge sort를 쓰면 무리 없이 풀리므로, inversion counting을 푸는 방법을 여기에 적는 것은 전자 낭비일 것 같다.

그래서, 여기서는 이 문제가 왜 inversion counting 문제로 reduction되는지 정도를 적어두려한다.

주어진 입력이 아래와 같았다고 해보자.

A: 132 392 311 351 231
B: 392 351 132 311 231

가장 무식하게 풀면 아래와 같은 모든 경우를 헤아려서 그 수를 더해주면 된다.

  • 132번 기계 입장에서는 A열에서는 132번 보다 뒤에 있지만, B열에서는 132번 보다 앞에 있는 기계 수를 세어준다.
  • 392번 기계 입장에서는 A열에서는 392번 보다 뒤에 있지만, B열에서는 392번 보다 앞에 있는 기계 수를 세어준다.

동일한 것을 조금만 달리 표현해보자.

B열의 기계들에 대해서 각 기계가 A열에서 어느 위치에 있었는지를 POS_AT_A와 같이 표현해보자.

POS_AT_A에서 inversion을 세어주면 이게 곧 위에서 구했던 교차 케이블 쌍의 수와 같음을 알 수 있다.

  • POS_AT_A에서 두 원소 쌍 POS_AT_A[i], POS_AT_A[j]의 대소관계는 곧 A열에서 두 기계 i, j의 순서를 나타낸다(작은 것이 더 앞에 온다).
  • 따라서, POS_AT_A[i] < POS_AT_A[j] && i > j가 의미하는 바는 기계 i가 기계 j보다 A열에서는 앞에 위치 해있었는데, B열에서는 기계 i보다 뒤에 있다는 것이다.
A       : 132 392 311 351 231
B       : 392 351 132 311 231
POS     :   0   1   2   3   4
POS_AT_A:   1   3   0   2   4

아래는 merge sort를 사용한 솔루션이다.

#include <cstdint>
#include <iostream>
#include <map>

using namespace std;

constexpr int MAX_N = 500'000;
constexpr int MAX_MACHINE_NO = 1'000'000;

// Inputs.
int N;
int ARR[MAX_N];
int POS[MAX_MACHINE_NO + 1];

// Solution.
int ARR_COPY[MAX_N];

int64_t CountInversion(int s, int e) {
  if (s >= e) {
    return 0;
  }
  int m = (s + e) / 2;
  int64_t inv = CountInversion(s, m) + CountInversion(m + 1, e);
  int offset = s;
  int left = s;
  int right = m + 1;
  while (left <= m && right <= e) {
    if (ARR[left] < ARR[right]) {
      ARR_COPY[offset++] = ARR[left++];
    } else {
      ARR_COPY[offset++] = ARR[right++];
      inv += m - left + 1;
    }
  }
  while (left <= m) {
    ARR_COPY[offset++] = ARR[left++];
  }
  while (right <= e) {
    ARR_COPY[offset++] = ARR[right++];
  }
  for (int i = s; i <= e; ++i) {
    ARR[i] = ARR_COPY[i];
  }
  return inv;
}

int64_t Solve() { return CountInversion(0, N - 1); }

int main() {
  // For faster IO.
  ios_base::sync_with_stdio(false);
  cout.tie(nullptr);
  cin.tie(nullptr);

  // Read inputs.
  cin >> N;
  for (int i = 0; i < N; ++i) {
    int machine_no;
    cin >> machine_no;
    POS[machine_no] = i;
  }
  for (int i = 0; i < N; ++i) {
    int machine_no;
    cin >> machine_no;
    ARR[i] = POS[machine_no];
  }

  // Solve.
  cout << Solve() << '\n';

  return 0;
}