結果
問題 | No.2305 [Cherry 5th Tune N] Until That Day... |
ユーザー | square1001 |
提出日時 | 2023-05-15 01:11:30 |
言語 | C++17 (gcc 12.3.0 + boost 1.83.0) |
結果 | AC |
実行時間 | 2,689 ms / 10,000 ms |
コード長 | 6,065 bytes |
コンパイル時間 | 1,618 ms |
コンパイル使用メモリ | 103,716 KB |
実行使用メモリ | 6,820 KB |
最終ジャッジ日時 | 2024-12-15 19:22:37 |
合計ジャッジ時間 | 32,425 ms |
ジャッジサーバーID (参考情報) | judge3 / judge4 |
(要ログイン)
テストケース
テストケース | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 2 ms | 6,816 KB |
testcase_01 | AC | 1 ms | 6,820 KB |
testcase_02 | AC | 5 ms | 6,820 KB |
testcase_03 | AC | 8 ms | 6,820 KB |
testcase_04 | AC | 19 ms | 6,816 KB |
testcase_05 | AC | 57 ms | 6,816 KB |
testcase_06 | AC | 204 ms | 6,820 KB |
testcase_07 | AC | 2,593 ms | 6,816 KB |
testcase_08 | AC | 2,626 ms | 6,820 KB |
testcase_09 | AC | 2,655 ms | 6,820 KB |
testcase_10 | AC | 2,641 ms | 6,816 KB |
testcase_11 | AC | 2,689 ms | 6,820 KB |
testcase_12 | AC | 2,584 ms | 6,816 KB |
testcase_13 | AC | 2,639 ms | 6,820 KB |
testcase_14 | AC | 2,680 ms | 6,816 KB |
testcase_15 | AC | 2,657 ms | 6,820 KB |
testcase_16 | AC | 2,488 ms | 6,816 KB |
testcase_17 | AC | 2 ms | 6,820 KB |
testcase_18 | AC | 2,670 ms | 6,820 KB |
testcase_19 | AC | 397 ms | 6,816 KB |
testcase_20 | AC | 407 ms | 6,816 KB |
ソースコード
// Solution for yukicoder No.2305: a pasted modint/NTT library followed by the solver.

#ifndef CLASS_MODINT
#define CLASS_MODINT
#include <cstdint>

// Integer modulo the compile-time modulus `mod`, kept as a canonical value in [0, mod).
// inv() relies on Fermat's little theorem, so it is only meaningful when `mod` is prime.
template <std::uint32_t mod>
class modint {
    // operator+= / operator-= add two values below `mod` in 32-bit arithmetic.
    static_assert(mod >= 1 && mod <= (std::uint32_t(1) << 31), "modulus must be in [1, 2^31]");

private:
    std::uint32_t n;

    // Maps any signed 64-bit value into [0, mod) without negating the input
    // (negating INT64_MIN would overflow).
    static std::uint32_t normalize(std::int64_t v) {
        const std::int64_t r = v % std::int64_t(mod);
        return std::uint32_t(r < 0 ? r + std::int64_t(mod) : r);
    }

public:
    modint() : n(0) {}
    // Intentionally implicit so integer literals convert (e.g. `mint x = 1;`).
    modint(std::int64_t n_) : n(normalize(n_)) {}
    static constexpr std::uint32_t get_mod() { return mod; }
    std::uint32_t get() const { return n; }
    bool operator==(const modint& m) const { return n == m.n; }
    bool operator!=(const modint& m) const { return n != m.n; }
    modint& operator+=(const modint& m) {
        n += m.n;
        if (n >= mod) n -= mod;
        return *this;
    }
    modint& operator-=(const modint& m) {
        n += mod - m.n;
        if (n >= mod) n -= mod;
        return *this;
    }
    modint& operator*=(const modint& m) {
        n = std::uint32_t(std::uint64_t(n) * m.n % mod);
        return *this;
    }
    modint operator+(const modint& m) const { return modint(*this) += m; }
    modint operator-(const modint& m) const { return modint(*this) -= m; }
    modint operator*(const modint& m) const { return modint(*this) *= m; }
    // Multiplicative inverse computed as (*this)^(mod - 2).
    modint inv() const { return pow(mod - 2); }
    // Binary exponentiation: returns (*this)^b.
    modint pow(std::uint64_t b) const {
        modint ans = 1, m = *this;
        while (b) {
            if (b & 1) ans *= m;
            m *= m;
            b >>= 1;
        }
        return ans;
    }
};
#endif  // CLASS_MODINT

#include <algorithm>
#include <cassert>
#include <iostream>
#include <vector>
using namespace std;

using mint = modint<998244353>;

// 998244353 - 1 = 119 * 2^23, so power-of-two transforms up to length 2^kMaxLog are possible.
constexpr int kMaxLog = 23;

// fwdpow[i] = 3^((mod - 1) / 2^i), a primitive 2^i-th root of unity (3 is a primitive root of
// 998244353); revpow[i] is its inverse. Filled for i = 0..kMaxLog at the start of main().
mint fwdpow[32], revpow[32];

// Radix-2 NTT of X in place; X.size() must be a power of two no larger than 2^kMaxLog.
// Each pass writes its butterflies into a scratch buffer Y, and no separate bit-reversal pass
// is performed. With inv == true the inverse roots (revpow) are used; the caller (convolve)
// divides by the length afterwards.
void fourier_transform(vector<mint>& X, bool inv) {
    int N = X.size();
    int level = 0;
    vector<mint> Y(N);
    for (int i = N >> 1; i >= 1; i >>= 1) {
        level += 1;
        mint p = (!inv ? fwdpow[level] : revpow[level]), mul = 1;
        for (int j = 0; 2 * j < N; j += i) {
            for (int k = 0; k < i; k++) {
                mint g = mul * X[2 * j + k + i];
                Y[j + k] = X[2 * j + k] + g;
                Y[j + k + N / 2] = X[2 * j + k] - g;
            }
            mul *= p;
        }
        // j + k covers [0, N/2) and j + k + N/2 covers [N/2, N), so every entry of Y was
        // overwritten above: exchanging the buffers is equivalent to copying Y into X.
        X.swap(Y);
    }
}

// Returns the product of polynomials A and B (coefficients, lowest degree first), which has
// length A.size() + B.size() - 1. Returns an empty vector if either input is empty.
vector<mint> convolve(vector<mint> A, vector<mint> B) {
    if (A.empty() || B.empty()) return {};
    int L = A.size() + B.size() - 1;
    int sz = 2;
    while (sz < L) sz *= 2;
    // The root-of-unity tables only cover lengths up to 2^kMaxLog.
    assert(sz <= (1 << kMaxLog));
    A.resize(sz);
    fourier_transform(A, false);
    B.resize(sz);
    fourier_transform(B, false);
    for (int i = 0; i < sz; ++i) A[i] *= B[i];
    fourier_transform(A, true);
    A.resize(L);
    const mint inv = mint(sz).inv();
    for (int i = 0; i < L; ++i) A[i] *= inv;
    return A;
}

// Returns the polynomial P of degree at most L - 1 (length L) such that
//   (C[0] + C[1] x + ... + C[N-1] x^(N-1)) * P(x) == 1  (mod x^L).
// Precondition: C[0] == 1.
// Newton iteration: if A is correct mod x^m, then A * (2 - C*A) is correct mod x^(2m).
// Since C*A == 1 (mod x^m), (2 - C*A) mod x^(2m) is 1 minus the coefficients of C*A in [m, 2m).
vector<mint> polynomial_inverse(vector<mint> C, int L) {
    assert(!C.empty() && C[0] == mint(1));
    assert(L >= 0);
    int N = C.size();
    vector<mint> A = {mint(1), mint(0)};
    int level = 0;
    while ((1 << level) < L) {
        int CS = min(2 << level, N);
        vector<mint> P = convolve(A, vector<mint>(C.begin(), C.begin() + CS));
        vector<mint> Q(2 << level);
        Q[0] = 1;
        for (int j = (1 << level); j < (2 << level); ++j) Q[j] = mint(0) - P[j];
        A = convolve(A, Q);
        A.resize(4 << level);
        ++level;
    }
    A.resize(L);
    return A;
}

// Directed child edge of the input tree: target vertex and its (mod p) weight.
class edge {
public:
    int to;
    mint weight;
    edge() : to(-1), weight(mint()) {}
    edge(int to_, const mint& weight_) : to(to_), weight(weight_) {}
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // step #0. prepare root-of-unity tables for the NTT
    for (int i = 0; i <= kMaxLog; ++i) {
        fwdpow[i] = mint(3).pow((mint::get_mod() - 1) >> i);
        revpow[i] = fwdpow[i].inv();
    }

    // step #1. read input (without queries) & make graph; vertex 0 is the root
    int N;
    cin >> N;
    N += 1;
    vector<int> P(N, -1);
    for (int i = 1; i < N; i++) cin >> P[i];
    vector<vector<edge>> G(N);
    for (int i = 1; i < N; i++) {
        int x;
        cin >> x;
        G[P[i]].push_back(edge(i, mint(x)));
    }

    // step #2. query-independent values:
    //   depth[v], and prob[child] = prob[parent] * weight / (sum of weights out of parent).
    // This single forward pass is only correct if P[i] < i for every i
    // (NOTE(review): presumably guaranteed by the input format).
    vector<int> depth(N, 0);
    vector<mint> prob(N);
    prob[0] = 1;
    for (int i = 0; i < N; i++) {
        if (G[i].empty()) continue;
        mint total = 0;
        for (const edge& e : G[i]) total += e.weight;
        const mint scale = prob[i] * total.inv();
        for (const edge& e : G[i]) {
            depth[e.to] = depth[i] + 1;
            prob[e.to] = e.weight * scale;
        }
    }

    // step #3 (query-independent part).
    // v1 = 1 - sum over leaves of prob[leaf] * x^(depth[leaf] + 1)
    vector<mint> v1(N + 1);
    v1[0] = 1;
    for (int i = 0; i < N; i++) {
        if (G[i].empty()) v1[depth[i] + 1] -= prob[i];
    }
    // v3[i] = -(v1[i + 1] + ... + v1[N])
    vector<mint> v3(N);
    v3[N - 1] = mint(0) - v1[N];
    for (int i = N - 2; i >= 0; i--) v3[i] = v3[i + 1] - v1[i + 1];
    // v8 = v1^2 (v8[0] == 1) and its power-series inverse v9 mod x^(2N+1).
    vector<mint> v8 = convolve(v1, v1);
    const vector<mint> v9 = polynomial_inverse(v8, 2 * N + 1);
    // From here on v8 is only used reversed: a monic divisor of degree 2N.
    reverse(v8.begin(), v8.end());

    // Reduces p (length in [2N, 4N]) modulo the reversed v8 and returns the 2N coefficients of
    // the remainder. The quotient is obtained from the reversed high part of p times v9.
    auto get_mod = [&](const vector<mint>& p) {
        vector<mint> p1(p.begin() + 2 * N, p.end());
        p1.resize(2 * N);
        reverse(p1.begin(), p1.end());
        vector<mint> p2 = convolve(p1, v9);
        p2.resize(2 * N);
        reverse(p2.begin(), p2.end());
        vector<mint> p3 = convolve(p2, v8);
        vector<mint> res(2 * N);
        for (int j = 0; j < 2 * N; j++) res[j] = p[j] - p3[j];
        return res;
    };

    // Per-query work: only the parts that depend on (K, mark).
    auto solve = [&](int K, int mark) {
        // flag[v] != 0 iff v is mark or a descendant of mark (same P[i] < i assumption).
        vector<char> flag(N, 0);
        flag[mark] = 1;
        for (int i = 0; i < N; i++) {
            if (!flag[i]) continue;
            for (const edge& e : G[i]) flag[e.to] = 1;
        }
        // v2 = sum over leaves in mark's subtree of prob[leaf] * x^(depth[leaf] + 1)
        vector<mint> v2(N + 1);
        for (int i = 0; i < N; i++) {
            if (G[i].empty() && flag[i]) v2[depth[i] + 1] += prob[i];
        }
        // v4[i] = v2[i + 1] + ... + v2[N] for i >= depth[mark]; lower entries stay 0.
        vector<mint> v4(N);
        v4[N - 1] = v2[N];
        for (int i = N - 2; i >= depth[mark]; i--) v4[i] = v4[i + 1] + v2[i + 1];

        // step #4. numerator v7 = v1 * v4 + v2 * v3 (the denominator is v1^2)
        vector<mint> v5 = convolve(v1, v4);
        vector<mint> v6 = convolve(v2, v3);
        vector<mint> v7(2 * N);
        for (int i = 0; i < 2 * N; i++) v7[i] = v5[i] + v6[i];

        // step #5. coefficient of x^K in v7 / v1^2: compute x^K mod (reversed v8) by binary
        // powering, then multiply by the reversed v7 and reduce once more.
        vector<int> track = {K};
        while (track.back() >= 2 * N) track.push_back(track.back() / 2);
        reverse(track.begin(), track.end());
        reverse(v7.begin(), v7.end());
        vector<mint> poly(2 * N);
        poly[track[0]] = 1;
        for (int i = 1; i < int(track.size()); i++) {
            poly = convolve(poly, poly);
            if (track[i] % 2 == 1) poly.insert(poly.begin(), mint(0));  // multiply by x
            poly = get_mod(poly);
        }
        poly = convolve(poly, v7);
        poly = get_mod(poly);

        // step #6. calculate answer
        mint answer = poly[2 * N - 1];
        if (mark == 0) answer -= 1;
        return answer;
    };

    // step #7. process queries ('\n' instead of endl: one flush at exit, not one per query)
    int Q;
    cin >> Q;
    for (int i = 0; i < Q; i++) {
        int a, k;
        cin >> a >> k;
        cout << solve(k, a).get() << '\n';
    }
    return 0;
}