結果

問題 No.590 Replacement
ユーザー 夕叢霧香(ゆうむらきりか)
提出日時 2017-11-04 00:47:17
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 122 ms / 2,000 ms
コード長 4,730 bytes
コンパイル時間 1,373 ms
コンパイル使用メモリ 104,916 KB
実行使用メモリ 18,052 KB
最終ジャッジ日時 2024-11-23 15:34:23
合計ジャッジ時間 5,804 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
5,248 KB
testcase_01 AC 2 ms
5,248 KB
testcase_02 AC 1 ms
5,248 KB
testcase_03 AC 2 ms
5,248 KB
testcase_04 AC 1 ms
5,248 KB
testcase_05 AC 1 ms
5,248 KB
testcase_06 AC 2 ms
5,248 KB
testcase_07 AC 7 ms
5,248 KB
testcase_08 AC 5 ms
5,248 KB
testcase_09 AC 8 ms
5,248 KB
testcase_10 AC 5 ms
5,248 KB
testcase_11 AC 3 ms
5,248 KB
testcase_12 AC 8 ms
5,248 KB
testcase_13 AC 36 ms
5,248 KB
testcase_14 AC 53 ms
5,248 KB
testcase_15 AC 5 ms
5,248 KB
testcase_16 AC 14 ms
5,248 KB
testcase_17 AC 23 ms
5,248 KB
testcase_18 AC 83 ms
7,504 KB
testcase_19 AC 83 ms
6,784 KB
testcase_20 AC 19 ms
5,248 KB
testcase_21 AC 99 ms
7,692 KB
testcase_22 AC 102 ms
9,052 KB
testcase_23 AC 107 ms
8,300 KB
testcase_24 AC 106 ms
8,140 KB
testcase_25 AC 108 ms
7,680 KB
testcase_26 AC 107 ms
7,684 KB
testcase_27 AC 109 ms
8,280 KB
testcase_28 AC 112 ms
7,296 KB
testcase_29 AC 113 ms
7,296 KB
testcase_30 AC 111 ms
7,168 KB
testcase_31 AC 111 ms
8,140 KB
testcase_32 AC 109 ms
8,236 KB
testcase_33 AC 2 ms
5,248 KB
testcase_34 AC 2 ms
5,248 KB
testcase_35 AC 2 ms
5,248 KB
testcase_36 AC 86 ms
9,504 KB
testcase_37 AC 87 ms
9,500 KB
testcase_38 AC 79 ms
11,848 KB
testcase_39 AC 83 ms
11,848 KB
testcase_40 AC 91 ms
11,540 KB
testcase_41 AC 89 ms
11,664 KB
testcase_42 AC 78 ms
10,812 KB
testcase_43 AC 80 ms
9,632 KB
testcase_44 AC 122 ms
18,052 KB
testcase_45 AC 85 ms
9,508 KB
testcase_46 AC 80 ms
9,508 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <vector>

using namespace std;
using lint = long long;      // at-least-64-bit integer used for all large arithmetic
using pii = pair<int, int>;  // (first, second) integer pair

// Modulus applied to the answer: 1e9+7 (an odd prime).
constexpr lint MOD = 1000000007;


// Greatest common divisor via Euclid's algorithm (recursive formulation).
// gcd(x, 0) == x; each step replaces (x, y) with (y, x % y).
int gcd(int x, int y) {
  return y == 0 ? x : gcd(y, x % y);
}

// Returns the inverse of a 0-based permutation: if permutation[i] == v,
// then the result has result[v] == i. Every value must lie in [0, size).
vector<int> inverse(const vector<int> &permutation) {
  vector<int> result(permutation.size());
  int position = 0;
  for (const int value : permutation) {
    result[value] = position;
    ++position;
  }
  return result;
}

// Decomposes a 0-based permutation into cycles, overwriting all three outputs.
//   cycle_id[i] : index of the cycle containing i (cycles are numbered in
//                 order of their smallest element)
//   order[i]    : number of steps from that cycle's smallest element to i,
//                 following i -> permutation[i]
//   period[c]   : length of cycle c
// The input must be a valid permutation; otherwise the walk never terminates.
void cycles(const vector<int> &permutation, vector<int> &cycle_id, vector<int> &period, vector<int> &order) {
  const int size = static_cast<int>(permutation.size());
  cycle_id.assign(size, -1);
  order.assign(size, 0);
  vector<int> lengths;
  for (int start = 0; start < size; ++start) {
    if (cycle_id[start] != -1) {
      continue;  // already visited as part of an earlier cycle
    }
    const int id = static_cast<int>(lengths.size());
    int steps = 0;
    int node = start;
    do {
      cycle_id[node] = id;
      order[node] = steps;
      ++steps;
      node = permutation[node];
    } while (node != start);
    lengths.push_back(steps);
  }
  period = lengths;
}

// Adapted from:
//   http://www.geeksforgeeks.org/chinese-remainder-theorem-set-2-implementation/
//   http://www.geeksforgeeks.org/multiplicative-inverse-under-modulo-m/
//
// Modular inverse by the extended Euclidean algorithm.
// Returns x with (a * x) % m == 1, shifted to be non-negative; returns 0 when
// m == 1. Requires gcd(a, m) == 1: otherwise the remainder sequence hits 0
// while a > 1 and the next division is by zero.
lint inv(lint a, lint m)
{
    const lint modulus = m;
    if (m == 1)
        return 0;

    // Invariant (mod the original m):
    //   current a == original a * coef_a
    //   current m == original a * coef_m
    lint coef_m = 0;
    lint coef_a = 1;
    while (a > 1)
    {
        const lint quotient = a / m;
        const lint remainder = a % m;
        a = m;
        m = remainder;

        const lint next_coef_m = coef_a - quotient * coef_m;
        coef_a = coef_m;
        coef_m = next_coef_m;
    }
    return coef_a < 0 ? coef_a + modulus : coef_a;
}
 
// Chinese Remainder Theorem for pairwise-coprime moduli.
// k is the size of num[] and rem[]. Returns the smallest non-negative x with
//   x ≡ rem[i] (mod num[i])   for every 0 <= i < k.
// Returns 0 when k == 0.
// Assumptions: every num[i] is positive, the num[] are pairwise coprime
// (gcd of every pair is 1), num[i]^2 fits in lint, and 2 * product(num)
// fits in lint.
lint findMinX(lint num[], lint rem[], int k)
{
    // Product of all moduli; the answer is unique modulo this value.
    lint prod = 1;
    for (int i = 0; i < k; i++)
        prod *= num[i];

    lint result = 0;
    for (int i = 0; i < k; i++)
    {
        const lint pp = prod / num[i];
        // Normalize rem[i] into [0, num[i]) so negative inputs still yield a
        // non-negative answer.
        lint r = rem[i] % num[i];
        if (r < 0)
            r += num[i];
        // Reduce the coefficient modulo num[i] *before* scaling by pp, so each
        // term stays below num[i] * pp == prod. The unreduced product
        // rem[i] * inv * pp can reach num[i]^2 * pp and overflow long before
        // prod itself does. (a mod n) * pp ≡ a * pp (mod prod), so the result
        // is unchanged whenever the old form did not overflow.
        const lint coeff = r * inv(pp, num[i]) % num[i];
        result = (result + coeff * pp) % prod;
    }

    return result;
}


// Solves t ≡ x (mod ma), t ≡ y (mod mb) for moduli sharing the factor g.
// Expects g == gcd(ma, mb), x ≡ y (mod g), and non-negative x, y. Under those
// conditions the result is the unique solution in [0, ma * (mb / g)).
// Method: write x = g*xq + r and y = g*yq + r (the same r = x % g). Solve
// s ≡ xq (mod ma/g), s ≡ yq (mod mb/g) with the coprime-moduli CRT, then
// lift back to t = s*g + r.
lint crt(int x, int y, int ma, int mb, int g) {
  const lint lcm = static_cast<lint>(ma) * (mb / g);
  const lint low = x % g;
  lint moduli[2] = {ma / g, mb / g};
  lint quotients[2] = {x / g, y / g};
  const lint s = findMinX(moduli, quotients, 2);
  return (s * g + low) % lcm;
}


// Sums gap*(gap-1)/2 (mod MOD) over the cyclic gaps between combined "times".
//   data  : (aord, bord) pairs to combine
//   ape   : first modulus (period of the a-cycle)
//   bpe   : second modulus (period of the b-cycle)
//   p_gcd : gcd(ape, bpe)
//   diff  : shift added to each bord before combining; crt expects
//           aord ≡ bord + diff (mod p_gcd) for every pair
// Each pair is mapped by crt to the t in [0, lcm) with t ≡ aord (mod ape) and
// t ≡ (bord + diff) % bpe (mod bpe). The sorted values are treated as points
// on a circle of circumference lcm = ape * (bpe / p_gcd). For every gap
// between neighbouring points, including the wrap-around gap, gap*(gap-1)/2
// is added. Returns 0 for empty data.
lint doit(const vector<pii> &data, int ape, int bpe, int p_gcd, int diff) {
  if (data.empty()) {
    return 0;
  }
  const lint lcm = (lint) ape * ((lint) bpe / p_gcd);

  vector<lint> times;
  times.reserve(data.size() + 1);  // +1 for the wrap-around sentinel below
  for (const pii &entry : data) {
    times.push_back(crt(entry.first, (entry.second + diff) % bpe, ape, bpe, p_gcd));
  }
  sort(times.begin(), times.end());
  // Close the circle: the gap after the largest time wraps to the smallest.
  times.push_back(times.front() + lcm);

  lint total = 0;
  for (size_t i = 0; i + 1 < times.size(); ++i) {
    // Named `gap` (not `diff`) so it does not shadow the parameter above.
    const lint gap = (times[i + 1] - times[i]) % MOD;
    // gap*(gap-1) is always even, and MOD is odd, so halving after reducing
    // gap mod MOD still gives C(gap, 2) mod MOD. gap < MOD keeps the product
    // below 2^63.
    total = (total + gap * (gap - 1) / 2) % MOD;
  }
  return total;
}


int main(void) {
  int n;
  cin >> n;
  vector<int> a(n), b(n);
  for (int i = 0; i < n; ++i) {
    cin >> a[i];
    a[i]--;
  }
  for (int i = 0; i < n; ++i) {
    cin >> b[i];
    b[i]--;
  }
  vector<int> ainv = inverse(a), binv = inverse(b);
  vector<int> acycle, aperiod, aorder;
  vector<int> bcycle, bperiod, border;
  cycles(ainv, acycle, aperiod, aorder);
  cycles(binv, bcycle, bperiod, border);
  map<pii, vector<int> > data;
  for (int i = 0; i < n; ++i) {
    pii now(acycle[i], bcycle[i]);
    data[now].push_back(i);
  }
  lint total = 0;
  for (map<pii, vector<int> >::iterator it = data.begin(); it != data.end(); ++it) {
    pii cycles = it->first;
    vector<int> indices = it->second;
    int acy = cycles.first;
    int bcy = cycles.second;
    int p_gcd = gcd(aperiod[acy], bperiod[bcy]);
    vector<vector<pii> > data(p_gcd);
    for (int i = 0; i < (int) indices.size(); ++i) {
      int index = indices[i];
      int aord = aorder[index];
      int bord = border[index];
      int diff = (aord - bord) % p_gcd;
      if (diff < 0) {
	diff += p_gcd;
      }
      data[diff].push_back(pii(aord, bord));
    }
    for (int i = 0; i < p_gcd; ++i) {
      total += doit(data[i], aperiod[acy], bperiod[bcy], p_gcd, i);
      total %= MOD;
    }
  }
  cout << total << endl;
}
0