結果

問題 No.590 Replacement
ユーザー 夕叢霧香(ゆうむらきりか)
提出日時 2017-11-04 00:31:30
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 5,983 bytes
コンパイル時間 1,611 ms
コンパイル使用メモリ 99,740 KB
最終ジャッジ日時 2025-01-05 03:47:39
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 3 ms
6,816 KB
testcase_01 AC 2 ms
6,820 KB
testcase_02 AC 2 ms
6,816 KB
testcase_03 AC 2 ms
6,820 KB
testcase_04 AC 2 ms
6,816 KB
testcase_05 AC 2 ms
6,816 KB
testcase_06 AC 2 ms
6,820 KB
testcase_07 AC 10 ms
6,820 KB
testcase_08 AC 6 ms
6,820 KB
testcase_09 AC 10 ms
6,820 KB
testcase_10 AC 6 ms
6,816 KB
testcase_11 AC 4 ms
6,816 KB
testcase_12 AC 9 ms
6,820 KB
testcase_13 AC 43 ms
6,816 KB
testcase_14 AC 62 ms
6,820 KB
testcase_15 AC 6 ms
6,820 KB
testcase_16 AC 18 ms
6,816 KB
testcase_17 AC 28 ms
6,820 KB
testcase_18 AC 100 ms
7,632 KB
testcase_19 AC 106 ms
6,816 KB
testcase_20 AC 25 ms
6,820 KB
testcase_21 AC 122 ms
7,696 KB
testcase_22 AC 120 ms
8,924 KB
testcase_23 AC 128 ms
8,180 KB
testcase_24 AC 143 ms
7,892 KB
testcase_25 AC 138 ms
7,680 KB
testcase_26 AC 133 ms
7,684 KB
testcase_27 AC 133 ms
8,280 KB
testcase_28 AC 133 ms
7,296 KB
testcase_29 AC 131 ms
7,168 KB
testcase_30 AC 134 ms
7,040 KB
testcase_31 AC 130 ms
8,020 KB
testcase_32 AC 133 ms
8,120 KB
testcase_33 AC 2 ms
6,820 KB
testcase_34 AC 1 ms
6,816 KB
testcase_35 AC 3 ms
6,824 KB
testcase_36 AC 104 ms
9,632 KB
testcase_37 AC 101 ms
9,504 KB
testcase_38 AC 94 ms
11,764 KB
testcase_39 AC 99 ms
11,976 KB
testcase_40 AC 122 ms
11,408 KB
testcase_41 AC 120 ms
11,408 KB
testcase_42 AC 95 ms
10,812 KB
testcase_43 AC 102 ms
9,636 KB
testcase_44 AC 149 ms
18,056 KB
testcase_45 WA -
testcase_46 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <vector>

using namespace std;
typedef long long lint;
typedef pair<int, int> pii;

const lint MOD = 1000000007;
const bool DEBUG = false;


// Greatest common divisor by Euclid's algorithm.
// gcd(x, 0) is x, so gcd(0, 0) is 0.
int gcd(int x, int y) {
  if (y == 0) {
    return x;
  }
  return gcd(y, x % y);
}

// Builds the inverse mapping of `permutation`: for every index i,
// the returned vector holds i at position permutation[i].
// Every entry of `permutation` is used as an index into the result, so the
// entries must lie in [0, permutation.size()).
vector<int> inverse(const vector<int> &permutation) {
  vector<int> positions(permutation.size());
  int index = 0;
  for (const int image : permutation) {
    positions[image] = index;
    ++index;
  }
  return positions;
}

// Decomposes `permutation` into cycles, scanning start indices in
// increasing order and following i -> permutation[i].
//
// Outputs (all overwritten):
//   cycle_id[i] - id of the cycle containing i; ids are handed out in the
//                 order the cycles are first encountered, starting at 0.
//   period[c]   - number of elements in cycle c.
//   order[i]    - number of steps needed to reach i from the first
//                 (smallest-index) element of its cycle.
void cycles(const vector<int> &permutation, vector<int> &cycle_id, vector<int> &period, vector<int> &order) {
  const int n = permutation.size();
  cycle_id.assign(n, -1);
  order.assign(n, 0);
  vector<int> lengths;
  for (int start = 0; start < n; ++start) {
    if (cycle_id[start] != -1) {
      continue;  // already visited as part of an earlier cycle
    }
    const int id = lengths.size();
    int steps = 0;
    int node = start;
    do {
      cycle_id[node] = id;
      order[node] = steps;
      ++steps;
      node = permutation[node];
    } while (node != start);
    lengths.push_back(steps);
  }
  period.swap(lengths);
}

// http://www.geeksforgeeks.org/chinese-remainder-theorem-set-2-implementation/

// Modular inverse of a with respect to m, computed with the extended
// Euclidean algorithm. Background:
// http://www.geeksforgeeks.org/multiplicative-inverse-under-modulo-m/
//
// Returns 0 when m == 1. The loop runs while a > 1 and divides by the
// current remainder, so a and m are expected to be coprime (otherwise the
// remainder reaches 0 while a is still greater than 1).
lint inv(lint a, lint m)
{
    if (m == 1)
        return 0;

    const lint modulus = m;  // original modulus, needed for the final fix-up
    lint previous = 0;       // Bezout coefficient paired with m
    lint current = 1;        // Bezout coefficient paired with a

    while (a > 1)
    {
        const lint quotient = a / m;
        const lint rest = a % m;

        // One Euclid step on the pair (a, m).
        a = m;
        m = rest;

        // Matching update of the coefficient pair.
        const lint next = current - quotient * previous;
        current = previous;
        previous = next;
    }

    // Shift a negative coefficient into the non-negative range.
    return current < 0 ? current + modulus : current;
}
 
// Chinese remainder theorem for k congruences. num[] and rem[] both have
// k entries. Returns x in [0, num[0] * ... * num[k-1]) such that
//   x % num[0]   = rem[0],
//   x % num[1]   = rem[1],
//   ..................
//   x % num[k-1] = rem[k-1]
// Assumption: the numbers in num[] are pairwise coprime
// (gcd of every pair is 1).
lint findMinX(lint num[], lint rem[], int k)
{
    // Product of all moduli.
    lint product = 1;
    for (int i = 0; i < k; i++)
    {
        product *= num[i];
    }

    // Sum of rem[i] * inverse(others mod num[i]) * others, where "others"
    // is the product of every modulus except num[i].
    lint sum = 0;
    for (int i = 0; i < k; i++)
    {
        const lint others = product / num[i];
        const lint coefficient = inv(others, num[i]);
        sum += rem[i] * coefficient * others;
    }

    return sum % product;
}


// Combines x (mod ma) and y (mod mb) when the moduli share the factor g.
// Both residues and both moduli are divided by g, the reduced system is
// solved with findMinX, and the result is scaled back by g with x % g
// added. The value is returned modulo ma * (mb / g).
// NOTE(review): the reduction only uses x % g, so the caller is expected to
// pass x and y with equal remainders modulo g, and g such that ma / g and
// mb / g are coprime - confirm at the call site.
lint crt(int x, int y, int ma, int mb, int g) {
  lint moduli[2] = {ma / g, mb / g};
  lint residues[2] = {x / g, y / g};
  const lint reduced = findMinX(moduli, residues, 2);
  const lint combined_modulus = (lint) ma * (mb / g);
  const lint lifted = reduced * g + (x % g);
  return lifted % combined_modulus;
}


// For every (first, second) pair in `data`, computes
//   crt(first, (second + diff) % bpe, ape, bpe, p_gcd),
// sorts the resulting values, appends the smallest one shifted by
// prod = ape * (bpe / p_gcd) so that the sequence wraps around, and returns
// the sum over consecutive gaps d of d * (d - 1) / 2, taken modulo MOD.
// Returns 0 when `data` is empty.
//
// A gap can be as large as prod, and the int period parameters allow prod to
// be far beyond 3.03e9, the point where d * (d - 1) no longer fits in a signed
// 64-bit integer. The even factor is therefore halved first and both factors
// are reduced modulo MOD before multiplying.
lint doit(const vector<pii> &data, int ape, int bpe, int p_gcd, int diff) {
  if (data.empty()) {
    return 0;
  }
  if (DEBUG) {
    cerr << "doit:" << ape << " " << bpe << endl;
    for (const pii &v : data) {
      cerr << v.first << " " << v.second << endl;
    }
  }
  const lint prod = (lint) ape * ((lint) bpe / p_gcd);
  vector<lint> tmp;
  tmp.reserve(data.size() + 1);
  for (const pii &d : data) {
    tmp.push_back(crt(d.first, (d.second + diff) % bpe, ape, bpe, p_gcd));
  }
  sort(tmp.begin(), tmp.end());
  // Close the cycle: the gap after the largest value wraps to the smallest.
  tmp.push_back(tmp[0] + prod);
  if (DEBUG) {
    cerr << "tmp:";
    for (const lint v : tmp) {
      cerr << " " << v;
    }
    cerr << endl;
  }
  lint total = 0;
  for (int i = 0; i + 1 < (int) tmp.size(); ++i) {
    const lint gap = tmp[i + 1] - tmp[i];
    // gap * (gap - 1) / 2 without overflow: exactly one of the two factors
    // is even, so divide that one by 2 and multiply the reduced residues.
    lint first = gap;
    lint second = gap - 1;
    if (first % 2 == 0) {
      first /= 2;
    } else {
      second /= 2;
    }
    total = (total + (first % MOD) * (second % MOD)) % MOD;
  }
  if (DEBUG) {
    cerr << "return: " << total << endl;
  }
  return total;
}


int main(void) {
  int n;
  cin >> n;
  vector<int> a(n), b(n);
  for (int i = 0; i < n; ++i) {
    cin >> a[i];
    a[i]--;
  }
  for (int i = 0; i < n; ++i) {
    cin >> b[i];
    b[i]--;
  }
  vector<int> ainv = inverse(a), binv = inverse(b);
  vector<int> acycle, aperiod, aorder;
  vector<int> bcycle, bperiod, border;
  cycles(ainv, acycle, aperiod, aorder);
  cycles(binv, bcycle, bperiod, border);
  if (DEBUG) {
    cerr << "acycle: ";
    for (int i = 0; i < n; ++i) {
      cerr << " " << acycle[i];
    }
    cerr << endl;
    cerr << "bcycle: ";
    for (int i = 0; i < n; ++i) {
      cerr << " " << bcycle[i];
    }
    cerr << endl;
  }
  map<pii, vector<int> > data;
  for (int i = 0; i < n; ++i) {
    pii now(acycle[i], bcycle[i]);
    data[now].push_back(i);
  }
  if (DEBUG) {
    for (auto value: data) {
      cerr << "data ";
      cerr << value.first.first << " " << value.first.second << ":";
      for (auto p: value.second) {
	cerr << " " << p;
      }
      cerr << endl;
    }
  }
  lint total = 0;
  for (map<pii, vector<int> >::iterator it = data.begin(); it != data.end(); ++it) {
    pii cycles = it->first;
    vector<int> indices = it->second;
    int acy = cycles.first;
    int bcy = cycles.second;
    int p_gcd = gcd(aperiod[acy], bperiod[bcy]);
    if (DEBUG) {
      cerr << "acy, bcy = " << acy << " " << bcy << ", pgcd = " << p_gcd << endl;
    }
    vector<vector<pii> > data(p_gcd);
    for (int i = 0; i < (int) indices.size(); ++i) {
      int index = indices[i];
      int aord = aorder[index];
      int bord = border[index];
      int diff = (aord - bord) % p_gcd;
      if (diff < 0) {
	diff += p_gcd;
      }
      data[diff].push_back(pii(aord, bord));
    }
    for (int i = 0; i < p_gcd; ++i) {
      total += doit(data[i], aperiod[acy], bperiod[bcy], p_gcd, i);
      total %= MOD;
    }
  }
#if 0
  for (int i = 0; i < n; ++i) {
    // (i, i) no gyaku
    lint count = 0;
    int ai = i;
    int bi = i;
    while (true) {
      if (count > 0 && ai == bi) {
	break;
      }
      ai = ainv[ai];
      bi = binv[bi];
      count++;
    }
    total = (total + count * (count - 1) / 2) % MOD;
    
  }
#endif
  cout << total << endl;
}
0