結果

問題 No.590 Replacement
ユーザー 夕叢霧香(ゆうむらきりか)
提出日時 2017-11-04 00:19:50
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 4,946 bytes
コンパイル時間 1,634 ms
コンパイル使用メモリ 104,736 KB
実行使用メモリ 18,048 KB
最終ジャッジ日時 2024-05-02 20:59:38
合計ジャッジ時間 5,267 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
5,248 KB
testcase_01 AC 2 ms
5,376 KB
testcase_02 AC 2 ms
5,376 KB
testcase_03 WA -
testcase_04 WA -
testcase_05 WA -
testcase_06 WA -
testcase_07 WA -
testcase_08 WA -
testcase_09 WA -
testcase_10 WA -
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
testcase_22 WA -
testcase_23 WA -
testcase_24 WA -
testcase_25 WA -
testcase_26 WA -
testcase_27 WA -
testcase_28 WA -
testcase_29 WA -
testcase_30 WA -
testcase_31 WA -
testcase_32 WA -
testcase_33 AC 2 ms
5,376 KB
testcase_34 AC 2 ms
5,376 KB
testcase_35 AC 2 ms
5,376 KB
testcase_36 WA -
testcase_37 WA -
testcase_38 AC 83 ms
11,848 KB
testcase_39 AC 82 ms
11,848 KB
testcase_40 AC 92 ms
11,664 KB
testcase_41 AC 92 ms
11,540 KB
testcase_42 AC 81 ms
10,804 KB
testcase_43 AC 80 ms
9,636 KB
testcase_44 AC 124 ms
18,048 KB
testcase_45 WA -
testcase_46 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <vector>

using namespace std;

// Shorthand aliases used throughout the solution.
using lint = long long;
using pii = pair<int, int>;

// Modulus applied to the printed answer.
constexpr lint MOD = 1000000007;
// When true, intermediate state is traced to stderr.
constexpr bool DEBUG = false;


// Greatest common divisor of x and y by Euclid's algorithm.
// Returns x unchanged when y is 0 (so gcd(0, 0) == 0).
int gcd(int x, int y) {
  if (y == 0) {
    return x;
  }
  return gcd(y, x % y);
}

// Builds the inverse mapping of `permutation`: for every position `pos`,
// the returned vector satisfies result[permutation[pos]] == pos.
// The input is expected to hold each value in [0, size) exactly once.
vector<int> inverse(const vector<int> &permutation) {
  vector<int> result(permutation.size());
  int pos = 0;
  for (int image : permutation) {
    result[image] = pos;
    ++pos;
  }
  return result;
}

// Decomposes `permutation` into its cycles.
//
// Outputs (all overwritten):
//   cycle_id[v] - index of the cycle containing v; cycles are numbered in
//                 order of their smallest element.
//   period[c]   - length of cycle c.
//   order[v]    - number of applications of `permutation` needed to reach v
//                 from the smallest element of its cycle.
void cycles(const vector<int> &permutation, vector<int> &cycle_id, vector<int> &period, vector<int> &order) {
  const int n = permutation.size();
  cycle_id.assign(n, -1);
  order.assign(n, 0);
  vector<int> lengths;
  for (int start = 0; start < n; ++start) {
    if (cycle_id[start] != -1) {
      continue;  // already reached from a smaller start
    }
    const int id = lengths.size();
    int steps = 0;
    int node = start;
    // Walk the cycle once, labelling every node with its distance from start.
    do {
      cycle_id[node] = id;
      order[node] = steps;
      ++steps;
      node = permutation[node];
    } while (node != start);
    lengths.push_back(steps);
  }
  period = lengths;
}


// http://techtipshoge.blogspot.jp/2015/02/blog-post_15.html
// Extended Euclid: on return, a * fst + b * snd == gcd(a, b).
void extgcd(int a, int b, int &fst, int &snd) {
  if (b == 0) {
    fst = 1;
    snd = 0;
    return;
  }
  extgcd(b, a % b, fst, snd);
  int new_snd = fst - (a / b) * snd;
  fst = snd;
  snd = new_snd;
}

// Chinese remainder theorem for two (not necessarily coprime) moduli.
//
// Returns the unique value v in [0, lcm(ma, mb)) with
//   v == x (mod ma)  and  v == y (mod mb).
//
// Preconditions: g == gcd(ma, mb), x >= 0, y >= 0 and x % g == y % g
// (otherwise no solution exists; this is not checked).
lint crt(int x, int y, int ma, int mb, int g) {
  // Common remainder modulo g; factoring it out leaves a coprime system.
  const int r = x % g;
  ma /= g;
  mb /= g;
  x /= g;
  y /= g;

  int p, q;
  extgcd(ma, mb, p, q);  // ma * p + mb * q == 1
  // p is the inverse of ma modulo mb, so it must be normalised modulo mb
  // (and q, the inverse of mb modulo ma, modulo ma). Reducing them by the
  // other modulus would destroy the inverse property.
  p %= mb;
  if (p < 0) {
    p += mb;
  }
  q %= ma;
  if (q < 0) {
    q += ma;
  }

  // Solve v == x (mod ma), v == y (mod mb) for the reduced, coprime moduli.
  // Each partial product is reduced early so every intermediate stays below
  // max(ma, mb)^2 and fits comfortably in 64 bits.
  const lint coprime_prod = (lint) ma * mb;
  const lint x_part = (lint) (x % ma) * q % ma * mb;  // == x (mod ma), == 0 (mod mb)
  const lint y_part = (lint) (y % mb) * p % mb * ma;  // == 0 (mod ma), == y (mod mb)
  const lint v = (x_part + y_part) % coprime_prod;

  // Undo the division by g and restore the common remainder.
  return v * g + r;
}


// Computes the contribution (mod MOD) of one residue class of one
// (a-cycle, b-cycle) pair.
//
// Parameters:
//   data  - (order in a-cycle, order in b-cycle) pairs of the indices in the
//           class.
//   ape   - length of the a-cycle.
//   bpe   - length of the b-cycle.
//   p_gcd - gcd(ape, bpe), as passed by the caller.
//   diff  - the class's value of (a-order - b-order) mod p_gcd; added to the
//           b-order so that both residues agree modulo p_gcd before the CRT.
//
// Every pair is mapped through crt() to a position on a circle of length
// lcm(ape, bpe). For each gap `d` between circularly adjacent positions,
// d * (d - 1) / 2 is added to the result. Returns 0 for an empty class.
lint doit(const vector<pii> &data, int ape, int bpe, int p_gcd, int diff) {
  if (data.empty()) {
    return 0;
  }
  if (DEBUG) {
    cerr << "doit:" << ape << " " << bpe << endl;
    for (auto v: data) {
      cerr << v.first << " " << v.second << endl;
    }
  }
  const lint prod = (lint) ape * (lint) bpe / p_gcd;  // lcm(ape, bpe)
  vector<lint> tmp;
  tmp.reserve(data.size() + 1);
  for (const pii &d : data) {
    tmp.push_back(crt(d.first, (d.second + diff) % bpe, ape, bpe, p_gcd));
  }
  sort(tmp.begin(), tmp.end());
  // Close the circle: the successor of the last position is the first one,
  // one full period later.
  tmp.push_back(tmp[0] + prod);
  if (DEBUG) {
    cerr << "tmp:";
    for (auto v: tmp) {
      cerr << " " << v;
    }
    cerr << endl;
  }
  lint total = 0;
  for (size_t i = 0; i + 1 < tmp.size(); ++i) {
    const lint gap = tmp[i + 1] - tmp[i];
    // gap can be as large as lcm(ape, bpe), so gap * (gap - 1) may not fit in
    // 64 bits. Exactly one of the two factors is even: halve that one first,
    // then reduce both modulo MOD before multiplying.
    lint first = gap;
    lint second = gap - 1;
    if (first % 2 == 0) {
      first /= 2;
    } else {
      second /= 2;
    }
    total = (total + (first % MOD) * (second % MOD)) % MOD;
  }
  if (DEBUG) {
    cerr << "return: " << total << endl;
  }
  return total;
}


int main(void) {
  int n;
  cin >> n;
  vector<int> a(n), b(n);
  for (int i = 0; i < n; ++i) {
    cin >> a[i];
    a[i]--;
  }
  for (int i = 0; i < n; ++i) {
    cin >> b[i];
    b[i]--;
  }
  vector<int> ainv = inverse(a), binv = inverse(b);
  vector<int> acycle, aperiod, aorder;
  vector<int> bcycle, bperiod, border;
  cycles(ainv, acycle, aperiod, aorder);
  cycles(binv, bcycle, bperiod, border);
  if (DEBUG) {
    cerr << "acycle: ";
    for (int i = 0; i < n; ++i) {
      cerr << " " << acycle[i];
    }
    cerr << endl;
    cerr << "bcycle: ";
    for (int i = 0; i < n; ++i) {
      cerr << " " << bcycle[i];
    }
    cerr << endl;
  }
  map<pii, vector<int> > data;
  for (int i = 0; i < n; ++i) {
    pii now(acycle[i], bcycle[i]);
    data[now].push_back(i);
  }
  if (DEBUG) {
    for (auto value: data) {
      cerr << "data ";
      cerr << value.first.first << " " << value.first.second << ":";
      for (auto p: value.second) {
	cerr << " " << p;
      }
      cerr << endl;
    }
  }
  lint total = 0;
  for (map<pii, vector<int> >::iterator it = data.begin(); it != data.end(); ++it) {
    pii cycles = it->first;
    vector<int> indices = it->second;
    int acy = cycles.first;
    int bcy = cycles.second;
    int p_gcd = gcd(aperiod[acy], bperiod[bcy]);
    if (DEBUG) {
      cerr << "acy, bcy = " << acy << " " << bcy << ", pgcd = " << p_gcd << endl;
    }
    vector<vector<pii> > data(p_gcd);
    for (int i = 0; i < (int) indices.size(); ++i) {
      int index = indices[i];
      int aord = aorder[index];
      int bord = border[index];
      int diff = (aord - bord) % p_gcd;
      if (diff < 0) {
	diff += p_gcd;
      }
      data[diff].push_back(pii(aord, bord));
    }
    for (int i = 0; i < p_gcd; ++i) {
      total += doit(data[i], aperiod[acy], bperiod[bcy], p_gcd, i);
      total %= MOD;
    }
  }
#if 0
  for (int i = 0; i < n; ++i) {
    // (i, i) no gyaku
    lint count = 0;
    int ai = i;
    int bi = i;
    while (true) {
      if (count > 0 && ai == bi) {
	break;
      }
      ai = ainv[ai];
      bi = binv[bi];
      count++;
    }
    total = (total + count * (count - 1) / 2) % MOD;
    
  }
#endif
  cout << total << endl;
}
0