結果

問題 No.1978 Permutation Repetition
ユーザー shino16shino16
提出日時 2022-06-10 23:00:57
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
CE  
(最新)
AC  
(最初)
実行時間 -
コード長 4,011 bytes
コンパイル時間 2,095 ms
コンパイル使用メモリ 215,556 KB
最終ジャッジ日時 2024-11-15 02:18:56
合計ジャッジ時間 2,551 ms
ジャッジサーバーID
(参考情報)
judge2 / judge5
このコードへのチャレンジ
(要ログイン)
コンパイルエラー時のメッセージ・ソースコードは、提出者また管理者しか表示できないようにしております。(リジャッジ後のコンパイルエラーは公開されます)
ただし、clay言語の場合は開発者のデバッグのため、公開されます。

コンパイルメッセージ
main.cpp: In function 'std::vector<atcoder::static_modint<1000000007> > exp(std::vector<atcoder::static_modint<1000000007> >, int)':
main.cpp:80:35: error: 'ceil_pow2' is not a member of 'atcoder::internal'
   80 |   rep(i, a.size()) al[i] = a[i].val();
      |                                   ^~~~     

ソースコード

diff #

#line 2 "lib/prelude.hpp"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define rep2(i, m, n) for (auto i = (m); i < (n); i++)
#define rep(i, n) rep2(i, 0, n)
#define repr2(i, m, n) for (auto i = (n); i-- > (m);)
#define repr(i, n) repr2(i, 0, n)
#define all(x) begin(x), end(x)
#line 2 "main.cpp"
#include <atcoder/convolution>
using mint = atcoder::modint1000000007;

#define sz(x) int(size(x))
using vi = vector<int>;

typedef complex<double> C;
typedef vector<double> vd;
void fft(vector<C>& a) {
int n = sz(a), L = 31 - __builtin_clz(n);
static vector<complex<long double>> R(2, 1);
static vector<C> rt(2, 1); // (^ 10% fas te r i f double)
for (static int k = 2; k < n; k *= 2) {
R.resize(n); rt.resize(n);
auto x = polar(1.0L, acos(-1.0L) / k);
rep2(i,k,2*k) rt[i] = R[i] = i&1 ? R[i/2] * x : R[i/2];
}
vi rev(n);
rep2(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
rep2(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int k = 1; k < n; k *= 2)
for (int i = 0; i < n; i += 2 * k) rep2(j,0,k) {
C z = rt[j+k] * a[i+j+k]; // (25% fas te r i f hand−ro l led )
a[i + j + k] = a[i + j] - z;
a[i + j] += z;
}
}
vd conv(const vd& a, const vd& b) {
if (a.empty() || b.empty()) return {};
vd res(sz(a) + sz(b) - 1);
int L = 32 - __builtin_clz(sz(res)), n = 1 << L;
vector<C> in(n), out(n);
copy(all(a), begin(in));
rep2(i,0,sz(b)) in[i].imag(b[i]);
fft(in);
for (C& x : in) x *= x;
rep2(i,0,n) out[i] = in[-i & (n - 1)] - conj(in[i]);
fft(out);
rep2(i,0,sz(res)) res[i] = imag(out[i]) / (4 * n);
return res;
}
typedef vector<ll> vl;
template<int M> vl convMod(const vl &a, const vl &b) {
if (a.empty() || b.empty()) return {};
vl res(sz(a) + sz(b) - 1);
int B=32-__builtin_clz(sz(res)), n=1<<B, cut=int(sqrt(M));
vector<C> L(n), R(n), outs(n), outl(n);
rep2(i,0,sz(a)) L[i] = C((int)a[i] / cut, (int)a[i] % cut);
rep2(i,0,sz(b)) R[i] = C((int)b[i] / cut, (int)b[i] % cut);
fft(L), fft(R);
rep2(i,0,n) {
int j = -i & (n - 1);
outl[j] = (L[i] + conj(L[j])) * R[i] / (2.0 * n);
outs[j] = (L[i] - conj(L[j])) * R[i] / (2.0 * n) / 1i;
}
fft(outl), fft(outs);
rep2(i,0,sz(res)) {
ll av = ll(real(outl[i])+.5), cv = ll(imag(outs[i])+.5);
ll bv = ll(imag(outl[i])+.5) + ll(real(outs[i])+.5);
res[i] = ((av % M * cut + bv) % M * cut + cv) % M;
}
return res;
}

vector<mint> conv(vector<mint> a, vector<mint> b, int deg = -1) {
  if (~deg)
    a.resize(min<int>(a.size(), deg)),
      b.resize(min<int>(b.size(), deg));
  vector<ll> al(a.size()), bl(b.size());
  rep(i, a.size()) al[i] = a[i].val();
  rep(i, b.size()) bl[i] = b[i].val();
  auto f = convMod<1000000007>(al, bl);
  if (~deg) f.resize(deg);
  return vector<mint>(all(f));
}

vector<mint> exp(vector<mint> h, int deg) {
  assert(h[0] == 0);
  int n = 1 << atcoder::internal::ceil_pow2(deg);
  h.resize(n);
  vector<mint> f = {1}, g = {1};
  vector<mint> q(h.size() - 1);
  rep(i, q.size()) q[i] = h[i + 1] * (i + 1);
  for (int m = 1; m < n; m *= 2) {
    auto fgg = conv(conv(f, g, m), g, m);
    g.resize(m);
    rep(i, m) g[i] = g[i] * 2 - fgg[i];
    auto fq = conv(f, {q.begin(), q.begin() + m - 1});
    rep(i, fq.size()) fq[i] =
      -fq[i] + (i < f.size() - 1 ? f[i + 1] * (i + 1) : 0);
    auto w = conv(g, fq, m * 2 - 1);
    rep(i, m - 1) w[i] += q[i];
    w.push_back(0);
    repr(i, m * 2 - 1) w[i + 1] = h[i + 1] - w[i] / (i + 1);
    w[0] = h[0];
    auto fw = conv(f, w, m * 2);
    f.resize(m * 2);
    rep(i, m * 2) f[i] += fw[i];
  }
  return f;
}

int N, M, A[1000];
bool used[1000];
int cycles[1001];

int main() {
  scanf("%d%d", &N, &M);
  rep(i, N) scanf("%d", A+i), A[i]--;

  rep(i, N) if (!used[i]) {
    int cnt = 0;
    for (int v = i; !used[v]; v = A[v])
      used[v] = true, cnt++;
    cycles[cnt]++;
  }

  mint ans = 1;
  rep(m, N+1) if (cycles[m] != 0) {
    int l = cycles[m];
    vector<mint> poly(l+1);
    rep2(k, 1, l+1) if (gcd(M, k * m) == k) {
      poly[k] = mint(m).pow(k-1) / k;
    }
    ans *= exp(move(poly), l+1)[l];
    rep(i, l) ans *= i+1;
  }

  printf("%d\n", ans.val());
}
0