結果
| 問題 | No.1392 Don't be together |
|---|---|
| コンテスト | |
| ユーザー | |
| 提出日時 | 2021-02-17 22:17:35 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 4 ms / 2,000 ms |
| コード長 | 3,843 bytes |
| コンパイル時間 | 915 ms |
| コンパイル使用メモリ | 66,968 KB |
| 最終ジャッジ日時 | 2025-01-18 21:54:56 |
| ジャッジサーバーID (参考情報) | judge3 / judge2 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 27 |
コンパイルメッセージ
main.cpp: In function ‘int main()’:
main.cpp:100:8: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
100 | scanf("%d%d", &n, &m);
| ~~~~~^~~~~~~~~~~~~~~~
main.cpp:103:10: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
103 | scanf("%d", &x);
| ~~~~~^~~~~~~~~~
ソースコード
#include <cstdio>
#include <vector>
#include <algorithm>
// Integer shorthands shared by every routine below.
using u32 = unsigned;
using u64 = unsigned long long;

// Problem input, filled in by main(): n is the number of entries read into p,
// m is the second number on the first input line.
u32 n = 0;
u32 m = 0;
// p[i] holds the 0-based (input value minus one) i-th input entry.
std::vector<u32> p;
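/*
Let J be the m x m all-ones matrix and I the m x m identity. The adjacency
matrix of the complete graph K_m is A = J - I, e.g. for m = 4:
A =
[0 1 1 1]
[1 0 1 1]
[1 1 0 1]
[1 1 1 0]
tr(A^k) counts closed walks of length k in K_m, i.e. the ways to colour a
cycle of length k with m colours so that neighbouring vertices differ
(the chromatic polynomial of the k-cycle evaluated at m).
Using J^0 = I (so tr(J^0) = m) and J^i = m^{i-1} J for i >= 1:
tr(A^k)
= tr((J - I)^k)
= \sum_{i=0}^k \binom{k}{i} (-1)^{k-i} tr(J^i)
= (-1)^k m + \sum_{i=1}^k \binom{k}{i} (-1)^{k-i} m^{i-1} tr(J)
= (-1)^k m + \sum_{i=1}^k \binom{k}{i} (-1)^{k-i} m^i
= (-1)^k m + ((m - 1)^k - (-1)^k)
= (m - 1)^k + (-1)^k (m - 1)
*/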
/// Integer modulo MOD, stored as its canonical residue `n` in [0, MOD).
///
/// Every operation keeps `n` in [0, MOD). The constructor reduces its
/// argument, and += / -= are arranged so that u32 wrap-around never corrupts
/// the residue, even for MOD above 2^31. Division is not provided.
template <u32 MOD>
struct Mint {
    u32 n;  // canonical residue, always < MOD
    // Reduce here so that arguments >= MOD (e.g. Mint(MOD + 1)) still satisfy
    // the n < MOD invariant the operators below depend on.
    constexpr Mint(u32 n = 0): n(n % MOD) {}
    constexpr Mint operator-() const { return Mint(n ? MOD - n: 0); }
    constexpr Mint &operator+=(const Mint &rhs){
        // (n + rhs.n) mod MOD without ever forming n + rhs.n, which could
        // overflow u32 when MOD > 2^31.
        if(n >= MOD - rhs.n) n -= MOD - rhs.n;
        else n += rhs.n;
        return *this;
    }
    constexpr Mint &operator-=(const Mint &rhs){
        // n + MOD may wrap modulo 2^32, but the following subtraction undoes the
        // wrap because the true result n + MOD - rhs.n is < MOD.
        if(rhs.n > n) n += MOD;
        n -= rhs.n;
        return *this;
    }
    // The product of two residues < 2^32 fits in u64, so a single % suffices.
    constexpr Mint &operator*=(const Mint &rhs){ n = (u64) n * rhs.n % MOD; return *this; }
    friend constexpr Mint operator+(const Mint &lhs, const Mint &rhs){ return Mint(lhs) += rhs; }
    friend constexpr Mint operator-(const Mint &lhs, const Mint &rhs){ return Mint(lhs) -= rhs; }
    friend constexpr Mint operator*(const Mint &lhs, const Mint &rhs){ return Mint(lhs) *= rhs; }
    friend constexpr bool operator==(const Mint &lhs, const Mint &rhs){ return lhs.n == rhs.n; }
    friend constexpr bool operator!=(const Mint &lhs, const Mint &rhs){ return lhs.n != rhs.n; }
};
// Binary exponentiation: returns a^n using O(log n) multiplications.
// T must be constructible from the integer 1 and support operator*=.
template <class T>
T mypow(T a, u32 n){
    T result = 1;
    while(n != 0){
        // Multiply in the current square whenever the low bit of n is set.
        if(n % 2 != 0) result *= a;
        a *= a;
        n /= 2;
    }
    return result;
}
// Multiplicative inverse of a via Fermat's little theorem: a^(MOD-2).
// NOTE(review): only correct when MOD is prime; for a == 0 (and MOD > 2)
// this returns 0 instead of reporting an error.
template <u32 MOD>
Mint<MOD> inv(Mint<MOD> a){
    constexpr u32 fermat_exponent = MOD - 2;
    return mypow<Mint<MOD>>(a, fermat_exponent);
}
// Prime modulus for all answers, and the matching modular-integer type.
constexpr u32 mod = 998'244'353;
using mint = Mint<mod>;
std::vector<u32> perm(const std::vector<u32> &p){
std::vector<u32> q = p;
std::vector<u32> cycles;
for(u32 i = 0; i < q.size(); i++){
if(q[i] == -1U) continue;
u32 c = 1;
u32 x = q[i];
q[i] = -1;
while(q[x] != -1U){
u32 z = q[x];
q[x] = -1;
x = z;
c++;
}
cycles.push_back(c);
}
return cycles;
}
/// Run-length summary of a multiset.
///
/// @param x  values to count (taken by const reference; a sorted copy is made)
/// @return   (value, multiplicity) pairs in ascending order of value;
///           an empty vector when x is empty.
template<typename T>
std::vector<std::pair<T, u32> > countv(const std::vector<T> &x){
    std::vector<std::pair<T, u32> > r;
    // y[0] is read below, so empty input must be handled up front.
    if(x.empty()) return r;
    auto y = x;
    std::sort(std::begin(y), std::end(y));
    T prev = y[0];
    u32 cnt = 1;
    for(std::size_t i = 1; i < y.size(); i++){
        if(y[i] == prev){
            cnt++;
        }else{
            // A new value starts: flush the finished run.
            r.emplace_back(prev, cnt);
            prev = y[i];
            cnt = 1;
        }
    }
    r.emplace_back(prev, cnt);
    return r;
}
int main(){
scanf("%d%d", &n, &m);
for(u32 i = 0; i < n; i++){
int x;
scanf("%d", &x);
p.push_back(x-1);
}
std::vector<u32> cycles = perm(p);
std::vector<mint> v(m+1, 1);
std::vector<u32> primes;
std::vector<u32> min_factors(m);
for(u32 i = 2; i < m; i++){
if(min_factors[i] == 0){
primes.push_back(i);
min_factors[i] = i;
}
for(u32 p: primes){
if(p * i >= m || p > min_factors[i]) break;
min_factors[p * i] = p;
}
}
auto cyc = countv(cycles);
for(auto [c, k]: cyc){
std::vector<mint> powc(m);
powc[1] = 1;
for(u32 i = 2; i < m; i++){
if(min_factors[i] == i) powc[i] = mypow(mint(i), c);
else powc[i] = powc[min_factors[i]] * powc[i / min_factors[i]];
}
for(u32 i = 2; i <= m; i++){
if(c&1)
v[i] *= mypow(powc[i-1] - mint(i-1), k);
// v[i] *= mypow(mint(i-1), c) - mint(i-1);
else
v[i] *= mypow(powc[i-1] + mint(i-1), k);
// v[i] *= mypow(mint(i-1), c) + mint(i-1);
}
}
std::vector<mint> fact(m+1);
fact[0] = 1;
for(u32 i = 1; i <= m; i++) fact[i] = fact[i-1] * i;
std::vector<mint> ifact(m+1);
ifact[m] = inv(fact[m]);
for(int i = m-1; i >= 0; i--) ifact[i] = ifact[i+1] * (i+1);
mint ans = 0;
for(u32 i = 2; i <= m; i++){
if((m-i)&1)
ans -= fact[m] * ifact[i] * ifact[m-i] * v[i];
else
ans += fact[m] * ifact[i] * ifact[m-i] * v[i];
}
printf("%d\n", (ans * ifact[m]).n);
return 0;
}