結果
問題 | No.1145 Sums of Powers |
ユーザー | 👑 hos.lyric |
提出日時 | 2020-07-31 22:52:41 |
言語 | C++14 (gcc 12.3.0 + boost 1.83.0) |
結果 |
WA
|
実行時間 | - |
コード長 | 6,942 bytes |
コンパイル時間 | 4,917 ms |
コンパイル使用メモリ | 271,248 KB |
実行使用メモリ | 7,392 KB |
最終ジャッジ日時 | 2024-07-06 20:20:16 |
合計ジャッジ時間 | 6,407 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
テストケース
入力 | 結果 | 実行時間 / 実行使用メモリ |
---|---|---|
testcase_00 | WA | - |
testcase_01 | WA | - |
testcase_02 | WA | - |
testcase_03 | WA | - |
testcase_04 | WA | - |
testcase_05 | WA | - |
ソースコード
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <complex>
#include <deque>
#include <functional>
#include <iostream>
#include <map>
#include <numeric>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace std;

using Int = long long;

template <class T1, class T2> ostream &operator<<(ostream &os, const pair<T1, T2> &a) { return os << "(" << a.first << ", " << a.second << ")"; }
template <class T> void pv(T a, T b) { for (T i = a; i != b; ++i) cerr << *i << " "; cerr << endl; }
template <class T> bool chmin(T &t, const T &f) { if (t > f) { t = f; return true; } return false; }
template <class T> bool chmax(T &t, const T &f) { if (t < f) { t = f; return true; } return false; }

// Integer modulo the compile-time modulus M_ (M_ must fit in int; inv() needs gcd(x, M_) == 1).
template <int M_> struct ModInt {
  static constexpr int M = M_;
  int x;  // canonical representative in [0, M)
  constexpr ModInt() : x(0) {}
  constexpr ModInt(long long x_) : x(x_ % M) { if (x < 0) x += M; }
  ModInt &operator+=(const ModInt &a) { x += a.x; if (x >= M) x -= M; return *this; }
  ModInt &operator-=(const ModInt &a) { x -= a.x; if (x < 0) x += M; return *this; }
  ModInt &operator*=(const ModInt &a) { x = static_cast<int>((static_cast<long long>(x) * a.x) % M); return *this; }
  ModInt &operator/=(const ModInt &a) { return (*this *= a.inv()); }
  ModInt operator+(const ModInt &a) const { return (ModInt(*this) += a); }
  ModInt operator-(const ModInt &a) const { return (ModInt(*this) -= a); }
  ModInt operator*(const ModInt &a) const { return (ModInt(*this) *= a); }
  ModInt operator/(const ModInt &a) const { return (ModInt(*this) /= a); }
  ModInt operator-() const { return ModInt(-x); }
  // Binary exponentiation: x^e for e >= 0.
  ModInt pow(long long e) const {
    ModInt x2 = x, xe = 1;
    for (; e; e >>= 1) {
      if (e & 1) xe *= x2;
      x2 *= x2;
    }
    return xe;
  }
  // Multiplicative inverse via the extended Euclidean algorithm (asserts invertibility).
  ModInt inv() const {
    int a = x, b = M, y = 1, z = 0, t;
    for (; ; ) {
      t = a / b; a -= t * b;
      if (a == 0) {
        assert(b == 1 || b == -1);
        return ModInt(b * z);
      }
      y -= t * z;
      t = b / a; b -= t * a;
      if (b == 0) {
        assert(a == 1 || a == -1);
        return ModInt(a * y);
      }
      z -= t * y;
    }
  }
  friend ModInt operator+(long long a, const ModInt &b) { return (ModInt(a) += b); }
  friend ModInt operator-(long long a, const ModInt &b) { return (ModInt(a) -= b); }
  friend ModInt operator*(long long a, const ModInt &b) { return (ModInt(a) *= b); }
  friend std::ostream &operator<<(std::ostream &os, const ModInt &a) { return os << a.x; }
};

constexpr int MO = 998244353;
using Mint = ModInt<MO>;

// Number-theoretic transform.
// M: prime, G: primitive root, 2^K must divide M - 1 (checked by static_assert).
template <int M, int G, int K> struct Fft {
  // Roots of unity in bit-reversed order, as fractions of a full turn:
  // 1, 1/4, 1/8, 3/8, 1/16, 5/16, 3/16, 7/16, ...
  int gs[1 << (K - 1)];
  constexpr Fft() : gs() {
    static_assert(2 <= K && K <= 30, "Fft: 2 <= K <= 30 must hold");
    static_assert(!((M - 1) & ((1 << K) - 1)), "Fft: 2^K | M - 1 must hold");
    gs[0] = 1;
    long long g2 = G, gg = 1;
    for (int e = (M - 1) >> K; e; e >>= 1) {
      if (e & 1) gg = (gg * g2) % M;
      g2 = (g2 * g2) % M;
    }
    gs[1 << (K - 2)] = gg;
    for (int l = 1 << (K - 2); l >= 2; l >>= 1) {
      gs[l >> 1] = (static_cast<long long>(gs[l]) * gs[l]) % M;
    }
    assert((static_cast<long long>(gs[1]) * gs[1]) % M == M - 1);
    for (int l = 2; l <= 1 << (K - 2); l <<= 1) {
      for (int i = 1; i < l; ++i) {
        gs[l + i] = (static_cast<long long>(gs[l]) * gs[i]) % M;
      }
    }
  }
  // In-place forward transform; xs.size() must be a power of two <= 2^K.
  void fft(vector<int> &xs) const {
    const int n = xs.size();
    assert(!(n & (n - 1)) && n <= 1 << K);
    for (int l = n; l >>= 1; ) {
      for (int i = 0; i < (n >> 1) / l; ++i) {
        const long long g = gs[i];
        for (int j = (i << 1) * l; j < (i << 1 | 1) * l; ++j) {
          const int t = (g * xs[j + l]) % M;
          if ((xs[j + l] = xs[j] - t) < 0) xs[j + l] += M;
          if ((xs[j] += t) >= M) xs[j] -= M;
        }
      }
    }
  }
  // In-place inverse transform WITHOUT the 1/n scaling (callers multiply by invN themselves).
  void invFft(vector<int> &xs) const {
    const int n = xs.size();
    assert(!(n & (n - 1)) && n <= 1 << K);
    for (int l = 1; l < n; l <<= 1) {
      std::reverse(xs.begin() + l, xs.begin() + (l << 1));
    }
    for (int l = 1; l < n; l <<= 1) {
      for (int i = 0; i < (n >> 1) / l; ++i) {
        const long long g = gs[i];
        for (int j = (i << 1) * l; j < (i << 1 | 1) * l; ++j) {
          int t = (g * (xs[j] - xs[j + l])) % M;
          if (t < 0) t += M;
          if ((xs[j] += xs[j + l]) >= M) xs[j] -= M;
          xs[j + l] = t;
        }
      }
    }
  }
  // Polynomial product of two non-empty integer sequences, reduced mod M.
  template <class T> vector<T> convolute(const vector<T> &as, const vector<T> &bs) const {
    const int na = as.size(), nb = bs.size();
    int n, invN = 1;
    // invN tracks 1/n mod M by halving alongside the doubling of n.
    for (n = 1; n < na + nb - 1; n <<= 1) {
      invN = ((invN & 1) ? (invN + M) : invN) >> 1;
    }
    vector<int> xs(n, 0), ys(n, 0);
    for (int i = 0; i < na; ++i) if ((xs[i] = as[i] % M) < 0) xs[i] += M;
    for (int i = 0; i < nb; ++i) if ((ys[i] = bs[i] % M) < 0) ys[i] += M;
    fft(xs);
    fft(ys);
    for (int i = 0; i < n; ++i) {
      xs[i] = (((static_cast<long long>(xs[i]) * ys[i]) % M) * invN) % M;
    }
    invFft(xs);
    xs.resize(na + nb - 1);
    return xs;
  }
  // Same as above for Mint sequences (values are already reduced).
  vector<Mint> convolute(const vector<Mint> &as, const vector<Mint> &bs) const {
    const int na = as.size(), nb = bs.size();
    int n, invN = 1;
    for (n = 1; n < na + nb - 1; n <<= 1) {
      invN = ((invN & 1) ? (invN + M) : invN) >> 1;
    }
    vector<int> xs(n, 0), ys(n, 0);
    for (int i = 0; i < na; ++i) xs[i] = as[i].x;
    for (int i = 0; i < nb; ++i) ys[i] = bs[i].x;
    fft(xs);
    fft(ys);
    for (int i = 0; i < n; ++i) {
      xs[i] = (((static_cast<long long>(xs[i]) * ys[i]) % M) * invN) % M;
    }
    invFft(xs);
    vector<Mint> ret(na + nb - 1);
    for (int i = 0; i < na + nb - 1; ++i) ret[i].x = xs[i];
    return ret;
  }
};

const Fft<998244353, 3, 20> FFT;

// Power-series inverse: returns g with f * g == 1 (mod x^n).
// Requires fs non-empty and fs[0] invertible. Uses Newton iteration g <- g (2 - f g).
vector<Mint> inv(const vector<Mint> &fs, int n) {
  assert(!fs.empty() && fs[0].x != 0);
  vector<Mint> gs{fs[0].inv()};
  for (int m = 1; m < n; m <<= 1) {
    // Only the first 2m coefficients of f influence g mod x^{2m}.
    const vector<Mint> fPrefix(fs.begin(), fs.begin() + min<int>(fs.size(), 2 * m));
    vector<Mint> hs = FFT.convolute(fPrefix, gs);
    hs.resize(2 * m);
    for (auto &h : hs) h = -h;
    hs[0] += 2;  // hs = 2 - f g  (mod x^{2m})
    gs = FFT.convolute(gs, hs);
    gs.resize(2 * m);
  }
  gs.resize(n);
  return gs;
}

int N, M;
vector<int> A;

// Divide and conquer over A[l, r). Returns
//   first  = prod_{l <= i < r} (1 - A_i x)                  (size r - l + 1)
//   second = sum_{l <= i < r} prod_{j != i} (1 - A_j x)      (size r - l)
// Requires r - l >= 1.
pair<vector<Mint>, vector<Mint>> solve(int l, int r) {
  if (r - l == 1) {
    return make_pair(vector<Mint>{1, -A[l]}, vector<Mint>{1});
  } else {
    const int mid = (l + r) / 2;
    const auto resL = solve(l, mid);
    const auto resR = solve(mid, r);
    const auto ret0 = FFT.convolute(resL.first, resR.first);
    auto ret1 = FFT.convolute(resL.second, resR.first);
    const auto tmp = FFT.convolute(resL.first, resR.second);
    for (int i = 0; i < r - l; ++i) {
      ret1[i] += tmp[i];
    }
    return make_pair(ret0, ret1);
  }
}

// Reads N M and A_1..A_N; prints sum_i A_i^k mod 998244353 for k = 1..M on one line.
// Since sum_i 1/(1 - A_i x) = second / first (from solve), the answer for k is
// the x^k coefficient of second * first^{-1}.
int main() {
  for (; ~scanf("%d%d", &N, &M); ) {
    A.resize(N);
    for (int i = 0; i < N; ++i) {
      scanf("%d", &A[i]);
    }
    vector<Mint> ans(M + 1);  // all zeros when N == 0 (empty sum)
    if (N > 0) {
      // solve(0, 0) would recurse forever, hence the guard.
      const auto res = solve(0, N);
      ans = FFT.convolute(inv(res.first, M + 1), res.second);
      ans.resize(M + 1);
    }
    for (int k = 1; k <= M; ++k) {
      printf("%s%d", (k == 1) ? "" : " ", ans[k].x);
    }
    puts("");
  }
  return 0;
}