結果
| 問題 | No.590 Replacement |
|---|---|
| コンテスト | |
| ユーザー | 夕叢霧香(ゆうむらきりか) |
| 提出日時 | 2017-11-04 00:32:34 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 152 ms / 2,000 ms |
| コード長 | 6,000 bytes |
| コンパイル時間 | 1,647 ms |
| コンパイル使用メモリ | 99,856 KB |
| 最終ジャッジ日時 | 2025-01-05 03:47:51 |
| ジャッジサーバーID (参考情報) | judge5 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 47 |
ソースコード
#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <vector>
using namespace std;
// Shorthand types used throughout the solution.
using lint = long long;
using pii = pair<int, int>;
// All answers are reported modulo this prime.
constexpr lint MOD = 1000000007;
// Set to true to dump intermediate state to stderr.
constexpr bool DEBUG = false;
// Greatest common divisor by Euclid's algorithm, written tail-recursively:
// gcd(x, 0) = x and gcd(x, y) = gcd(y, x % y).
// Uses C++ truncating %, so the sign of the result follows the last non-zero
// remainder exactly as in an iterative implementation.
int gcd(int x, int y) {
    return y == 0 ? x : gcd(y, x % y);
}
// Returns the inverse of a 0-based permutation: if permutation[i] == v then
// the result maps v back to i. Assumes every value lies in [0, size).
vector<int> inverse(const vector<int> &permutation) {
    vector<int> inv_perm(permutation.size());
    int position = 0;
    for (const int value : permutation) {
        inv_perm[value] = position++;
    }
    return inv_perm;
}
// Decomposes a 0-based permutation into cycles.
// Outputs (all overwritten):
//   cycle_id[i] - index of the cycle containing i (cycles numbered in order of
//                 their smallest element),
//   period[c]   - length of cycle c,
//   order[i]    - number of steps from the cycle's smallest element to i when
//                 repeatedly applying permutation.
void cycles(const vector<int> &permutation, vector<int> &cycle_id, vector<int> &period, vector<int> &order) {
    cycle_id.assign(permutation.size(), -1);
    order.assign(permutation.size(), 0);
    vector<int> lengths;
    const int size = permutation.size();
    for (int start = 0; start < size; ++start) {
        if (cycle_id[start] != -1) {
            continue;  // already visited as part of an earlier cycle
        }
        const int id = static_cast<int>(lengths.size());
        int length = 0;
        int node = start;
        do {
            cycle_id[node] = id;
            order[node] = length++;
            node = permutation[node];
        } while (node != start);
        lengths.push_back(length);
    }
    period.swap(lengths);
}
// http://www.geeksforgeeks.org/chinese-remainder-theorem-set-2-implementation/
// Returns the modular inverse of a with respect to m using the extended
// Euclidean algorithm. Background:
// http://www.geeksforgeeks.org/multiplicative-inverse-under-modulo-m/
//
// Returns x in [0, m) with (a * x) % m == 1. For m == 1 every value is
// congruent to 0, so 0 is returned.
// a may be any value (negative or >= m); it is reduced into [0, m) first.
// Preconditions (checked with assert): m > 0 and gcd(a, m) == 1.
lint inv(lint a, lint m)
{
    assert(m > 0);
    if (m == 1)
        return 0;
    const lint m0 = m;
    // Normalize a into [0, m). Without this a negative a (or 0) skips the loop
    // below and the function wrongly returns 1.
    a %= m;
    if (a < 0)
        a += m;
    // Invariant (mod m0): a_input * x1 == a and a_input * x0 == m.
    lint x0 = 0, x1 = 1;
    while (a > 1)
    {
        // m reaching 0 while a > 1 means gcd(a_input, m0) == a > 1: no inverse
        // exists, and dividing would be undefined behavior.
        assert(m != 0 && "inv: a is not invertible modulo m");
        // q is the quotient; (a, m) -> (m, a % m) as in Euclid's algorithm.
        lint q = a / m;
        lint t = m;
        m = a % m;
        a = t;
        // Update the Bezout coefficients to keep the invariant.
        t = x0;
        x0 = x1 - q * x0;
        x1 = t;
    }
    // a ends at 0 only when the reduced input was 0, which has no inverse.
    assert(a == 1 && "inv: a is not invertible modulo m");
    // Make x1 positive.
    if (x1 < 0)
        x1 += m0;
    return x1;
}
// Chinese Remainder Theorem for pairwise-coprime moduli.
// k is the size of num[] and rem[]. Returns the smallest non-negative x with
//   x % num[0] == rem[0],
//   x % num[1] == rem[1],
//   ..................
//   x % num[k-1] == rem[k-1]
// Assumption: the numbers in num[] are positive and pairwise coprime
// (gcd of every pair is 1), and their product fits in lint.
// Residues outside [0, num[i]) (including negative ones) are reduced first.
lint findMinX(lint num[], lint rem[], int k)
{
    // Product of all moduli; the solution is unique modulo this value.
    lint prod = 1;
    for (int i = 0; i < k; i++)
        prod *= num[i];
    lint result = 0;
    for (int i = 0; i < k; i++)
    {
        // pp is divisible by every other modulus, so this term only affects
        // the residue modulo num[i].
        lint pp = prod / num[i];
        lint r = rem[i] % num[i];
        if (r < 0)
            r += num[i];
        // Reduce rem * inv(pp) modulo num[i] before scaling by pp. Each term
        // then stays below prod instead of reaching num[i] * prod, which
        // overflows long before the answer does.
        lint coeff = r * inv(pp, num[i]) % num[i];
        result = (result + coeff * pp) % prod;
    }
    return result;
}
// Combines x (mod ma) and y (mod mb) into a single value modulo
// lcm(ma, mb) = ma * (mb / g), where g = gcd(ma, mb).
// Precondition: x % g == y % g (otherwise no solution exists).
// Only x % g is used as the shared remainder. The quotients x / g and y / g
// are combined with findMinX over the coprime moduli ma / g and mb / g, then
// the remainder is added back.
// NOTE(review): assumes x and y are non-negative, as the caller in this file
// passes orders/positions.
lint crt(int x, int y, int ma, int mb, int g) {
    const lint shared = x % g;
    lint moduli[2] = {ma / g, mb / g};
    lint quotients[2] = {x / g, y / g};
    const lint lcm = static_cast<lint>(ma) * (mb / g);
    const lint combined = findMinX(moduli, quotients, 2);
    return (combined * g + shared) % lcm;
}
// Scores one residue class for a pair of cycles with periods ape and bpe,
// where p_gcd = gcd(ape, bpe).
// Each entry (x, y) of data is mapped with crt(x, (y + diff) % bpe, ...) to a
// value in [0, lcm), where lcm = ape * (bpe / p_gcd). The values are sorted and
// treated cyclically: the first value is repeated one lcm later.
// Returns the sum over consecutive gaps g of g * (g - 1) / 2, modulo MOD.
// Returns 0 for an empty class.
lint doit(const vector<pii> &data, int ape, int bpe, int p_gcd, int diff) {
    if (data.empty()) {
        return 0;
    }
    if (DEBUG) {
        cerr << "doit:" << ape << " " << bpe << endl;
        for (const pii &entry : data) {
            cerr << entry.first << " " << entry.second << endl;
        }
    }
    const lint lcm = static_cast<lint>(ape) * (bpe / p_gcd);
    vector<lint> hits;
    hits.reserve(data.size() + 1);
    for (const pii &entry : data) {
        hits.push_back(crt(entry.first, (entry.second + diff) % bpe, ape, bpe, p_gcd));
    }
    sort(hits.begin(), hits.end());
    // Close the cycle: the smallest value recurs one full period later.
    hits.push_back(hits.front() + lcm);
    if (DEBUG) {
        cerr << "tmp:";
        for (const lint value : hits) {
            cerr << " " << value;
        }
        cerr << endl;
    }
    lint total = 0;
    for (size_t k = 1; k < hits.size(); ++k) {
        // gap * (gap - 1) is even and MOD is odd, so halving after the
        // reduction matches the exact value modulo MOD.
        const lint gap = (hits[k] - hits[k - 1]) % MOD;
        total = (total + gap * (gap - 1) / 2) % MOD;
    }
    if (DEBUG) {
        cerr << "return: " << total << endl;
    }
    return total;
}
int main(void) {
int n;
cin >> n;
vector<int> a(n), b(n);
for (int i = 0; i < n; ++i) {
cin >> a[i];
a[i]--;
}
for (int i = 0; i < n; ++i) {
cin >> b[i];
b[i]--;
}
vector<int> ainv = inverse(a), binv = inverse(b);
vector<int> acycle, aperiod, aorder;
vector<int> bcycle, bperiod, border;
cycles(ainv, acycle, aperiod, aorder);
cycles(binv, bcycle, bperiod, border);
if (DEBUG) {
cerr << "acycle: ";
for (int i = 0; i < n; ++i) {
cerr << " " << acycle[i];
}
cerr << endl;
cerr << "bcycle: ";
for (int i = 0; i < n; ++i) {
cerr << " " << bcycle[i];
}
cerr << endl;
}
map<pii, vector<int> > data;
for (int i = 0; i < n; ++i) {
pii now(acycle[i], bcycle[i]);
data[now].push_back(i);
}
if (DEBUG) {
for (auto value: data) {
cerr << "data ";
cerr << value.first.first << " " << value.first.second << ":";
for (auto p: value.second) {
cerr << " " << p;
}
cerr << endl;
}
}
lint total = 0;
for (map<pii, vector<int> >::iterator it = data.begin(); it != data.end(); ++it) {
pii cycles = it->first;
vector<int> indices = it->second;
int acy = cycles.first;
int bcy = cycles.second;
int p_gcd = gcd(aperiod[acy], bperiod[bcy]);
if (DEBUG) {
cerr << "acy, bcy = " << acy << " " << bcy << ", pgcd = " << p_gcd << endl;
}
vector<vector<pii> > data(p_gcd);
for (int i = 0; i < (int) indices.size(); ++i) {
int index = indices[i];
int aord = aorder[index];
int bord = border[index];
int diff = (aord - bord) % p_gcd;
if (diff < 0) {
diff += p_gcd;
}
data[diff].push_back(pii(aord, bord));
}
for (int i = 0; i < p_gcd; ++i) {
total += doit(data[i], aperiod[acy], bperiod[bcy], p_gcd, i);
total %= MOD;
}
}
#if 0
for (int i = 0; i < n; ++i) {
// (i, i) no gyaku
lint count = 0;
int ai = i;
int bi = i;
while (true) {
if (count > 0 && ai == bi) {
break;
}
ai = ainv[ai];
bi = binv[bi];
count++;
}
total = (total + count * (count - 1) / 2) % MOD;
}
#endif
cout << total << endl;
}
夕叢霧香(ゆうむらきりか)