結果

問題 No.1392 Don't be together
ユーザー Ryuhei MoriRyuhei Mori
提出日時 2021-02-17 22:17:35
言語 C++17
(gcc 13.2.0 + boost 1.83.0)
結果
AC  
実行時間 3 ms / 2,000 ms
コード長 3,843 bytes
コンパイル時間 669 ms
コンパイル使用メモリ 65,728 KB
実行使用メモリ 4,356 KB
最終ジャッジ日時 2023-10-12 08:11:59
合計ジャッジ時間 1,921 ms
ジャッジサーバーID
(参考情報)
judge11 / judge12
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
4,348 KB
testcase_01 AC 2 ms
4,348 KB
testcase_02 AC 1 ms
4,352 KB
testcase_03 AC 1 ms
4,352 KB
testcase_04 AC 1 ms
4,352 KB
testcase_05 AC 1 ms
4,348 KB
testcase_06 AC 2 ms
4,356 KB
testcase_07 AC 3 ms
4,352 KB
testcase_08 AC 2 ms
4,348 KB
testcase_09 AC 2 ms
4,348 KB
testcase_10 AC 2 ms
4,352 KB
testcase_11 AC 3 ms
4,348 KB
testcase_12 AC 2 ms
4,352 KB
testcase_13 AC 2 ms
4,352 KB
testcase_14 AC 3 ms
4,352 KB
testcase_15 AC 2 ms
4,348 KB
testcase_16 AC 3 ms
4,348 KB
testcase_17 AC 2 ms
4,352 KB
testcase_18 AC 2 ms
4,348 KB
testcase_19 AC 3 ms
4,352 KB
testcase_20 AC 2 ms
4,352 KB
testcase_21 AC 2 ms
4,352 KB
testcase_22 AC 2 ms
4,348 KB
testcase_23 AC 2 ms
4,352 KB
testcase_24 AC 2 ms
4,352 KB
testcase_25 AC 2 ms
4,352 KB
testcase_26 AC 2 ms
4,352 KB
testcase_27 AC 2 ms
4,348 KB
testcase_28 AC 2 ms
4,352 KB
testcase_29 AC 2 ms
4,348 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cstdio>
#include <vector>
#include <algorithm>

using u32 = unsigned int;
using u64 = unsigned long long int;

u32 n, m;
std::vector<u32> p;

/*
A =
[0 1 1 1]
[1 0 1 1]
[1 1 0 1]
[1 1 1 0]

tr(A^k)
= tr((1 - I)^k)
= \sum_{i=0}^k \binom{k}{i} (-1)^{k-i} tr(1^i)
= \sum_{i=0}^k \binom{k}{i} (-1)^{k-i} tr(m^{i-1} 1)
= (-1)^k * m + \sum_{i=1}^k \binom{k}{i} (-1)^{k-i} m^i
= (m - 1)^k + (-1)^k * (m - 1)

*/

template <u32 MOD>
struct Mint {
  u32 n;
  constexpr Mint(u32 n = 0): n(n) {}
  constexpr Mint operator-() const { return Mint(n ? MOD - n: 0); }
  constexpr Mint &operator+=(const Mint &rhs){ n += rhs.n; if(n >= MOD) n -= MOD; return *this; }
  constexpr Mint &operator-=(const Mint &rhs){ if(rhs.n > n) n += MOD; n -= rhs.n; return *this; }
  constexpr Mint &operator*=(const Mint &rhs){ n = (u64) n * rhs.n % MOD; return *this; }
  friend constexpr Mint operator+(const Mint &lhs, const Mint &rhs){ return Mint(lhs) += rhs; }
  friend constexpr Mint operator-(const Mint &lhs, const Mint &rhs){ return Mint(lhs) -= rhs; }
  friend constexpr Mint operator*(const Mint &lhs, const Mint &rhs){ return Mint(lhs) *= rhs; }
  friend constexpr bool operator==(const Mint &lhs, const Mint &rhs){ return lhs.n == rhs.n; }
  friend constexpr bool operator!=(const Mint &lhs, const Mint &rhs){ return lhs.n != rhs.n; }
};

template <class T>
T mypow(T a, u32 n){
  T r = 1;
  for(; n; n >>= 1){
    if(n&1) r *= a;
    a *= a;
  }
  return r;
}

template <u32 MOD>
Mint<MOD> inv(Mint<MOD> a){
  return mypow(a, MOD-2);
}

constexpr u32 mod = 998244353;
using mint = Mint<mod>;


std::vector<u32> perm(const std::vector<u32> &p){
  std::vector<u32> q = p;
  std::vector<u32> cycles;
  for(u32 i = 0; i < q.size(); i++){
    if(q[i] == -1U) continue;
    u32 c = 1;
    u32 x = q[i];
    q[i] = -1;
    while(q[x] != -1U){
      u32 z = q[x];
      q[x] = -1;
      x = z;
      c++;
    }
    cycles.push_back(c);
  }
  return cycles;
}

template<typename T>
std::vector<std::pair<T, u32> > countv(const std::vector<T> &x){
  auto y = x;
  std::vector<std::pair<T, u32> > r;
  std::sort(std::begin(y), std::end(y));
  u32 cnt = 1;
  T prev = y[0];
  for(u32 i = 1; i < y.size(); i++){
    if(y[i] == prev) cnt++;
    else {
      r.emplace_back(prev, cnt);
      cnt = 1;
      prev = y[i];
    }
  }
  r.emplace_back(prev, cnt);
  return r;
}

int main(){
  scanf("%d%d", &n, &m);
  for(u32 i = 0; i < n; i++){
    int x;
    scanf("%d", &x);
    p.push_back(x-1);
  }

  std::vector<u32> cycles = perm(p);

  std::vector<mint> v(m+1, 1);

  std::vector<u32> primes;
  std::vector<u32> min_factors(m);
  for(u32 i = 2; i < m; i++){
    if(min_factors[i] == 0){
      primes.push_back(i);
      min_factors[i] = i;
    }
    for(u32 p: primes){
      if(p * i >= m || p > min_factors[i]) break;
      min_factors[p * i] = p;
    }
  }

  auto cyc = countv(cycles);

  for(auto [c, k]: cyc){
    std::vector<mint> powc(m);
    powc[1] = 1;
    for(u32 i = 2; i < m; i++){
      if(min_factors[i] == i) powc[i] = mypow(mint(i), c);
      else powc[i] = powc[min_factors[i]] * powc[i / min_factors[i]];
    }
    for(u32 i = 2; i <= m; i++){
      if(c&1)
        v[i] *= mypow(powc[i-1] - mint(i-1), k);
//        v[i] *= mypow(mint(i-1), c) - mint(i-1);
      else
        v[i] *= mypow(powc[i-1] + mint(i-1), k);
//        v[i] *= mypow(mint(i-1), c) + mint(i-1);
    }
  }

  std::vector<mint> fact(m+1);
  fact[0] = 1;
  for(u32 i = 1; i <= m; i++) fact[i] = fact[i-1] * i;
  std::vector<mint> ifact(m+1);
  ifact[m] = inv(fact[m]);
  for(int i = m-1; i >= 0; i--) ifact[i] = ifact[i+1] * (i+1);

  mint ans = 0;
  for(u32 i = 2; i <= m; i++){
    if((m-i)&1)
      ans -= fact[m] * ifact[i] * ifact[m-i] * v[i];
    else
      ans += fact[m] * ifact[i] * ifact[m-i] * v[i];
  }

  printf("%d\n", (ans * ifact[m]).n);

  return 0;
}
0