結果
| 問題 | No.2305 [Cherry 5th Tune N] Until That Day... |
|---|---|
| コンテスト | |
| ユーザー | square1001 |
| 提出日時 | 2023-05-15 01:11:30 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 2,788 ms / 10,000 ms |
| コード長 | 6,065 bytes |
| コンパイル時間 | 1,331 ms |
| コンパイル使用メモリ | 102,056 KB |
| 最終ジャッジ日時 | 2025-02-13 00:37:29 |
| ジャッジサーバーID (参考情報) | judge5 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 21 |
ソースコード
#ifndef CLASS_MODINT
#define CLASS_MODINT
#include <cstdint>
// Residue class modulo the compile-time constant `mod`; the stored value is
// always kept in [0, mod).
//
// Requirements:
//   * 1 <= mod <= 2^31: operator+= and operator-= add two reduced values in
//     32-bit unsigned arithmetic before reducing, which cannot overflow only
//     under this bound.
//   * inv() relies on Fermat's little theorem, so it is only correct when
//     `mod` is prime; inv() of zero yields zero.
template <std::uint32_t mod>
class modint {
    static_assert(mod >= 1 && mod <= (std::uint32_t(1) << 31),
                  "modint: mod must lie in [1, 2^31] so that += and -= cannot overflow");
private:
    std::uint32_t n;
    // Maps any signed 64-bit value to its representative in [0, mod).
    // Works from the remainder directly instead of negating the argument,
    // because negating INT64_MIN is signed overflow (undefined behavior).
    static constexpr std::uint32_t normalize(std::int64_t v) {
        const std::int64_t m = static_cast<std::int64_t>(mod);
        const std::int64_t r = v % m;  // C++ truncates toward zero, so r is in (-m, m)
        return static_cast<std::uint32_t>(r < 0 ? r + m : r);
    }
public:
    modint() : n(0) {}
    // Deliberately implicit so integer literals mix with modint (e.g. `modint x = 1;`).
    modint(std::int64_t n_) : n(normalize(n_)) {}
    static constexpr std::uint32_t get_mod() { return mod; }
    std::uint32_t get() const { return n; }
    bool operator==(const modint& m) const { return n == m.n; }
    bool operator!=(const modint& m) const { return n != m.n; }
    modint& operator+=(const modint& m) { n += m.n; n = (n < mod ? n : n - mod); return *this; }
    modint& operator-=(const modint& m) { n += mod - m.n; n = (n < mod ? n : n - mod); return *this; }
    modint& operator*=(const modint& m) { n = std::uint64_t(n) * m.n % mod; return *this; }
    modint operator+(const modint& m) const { return modint(*this) += m; }
    modint operator-(const modint& m) const { return modint(*this) -= m; }
    modint operator*(const modint& m) const { return modint(*this) *= m; }
    // Multiplicative inverse computed as (*this)^(mod-2); requires prime mod.
    modint inv() const { return pow(mod - 2); }
    // Binary exponentiation: returns (*this)^b.
    modint pow(std::uint64_t b) const {
        modint ans = 1, m = modint(*this);
        while (b) {
            if (b & 1) ans *= m;
            m *= m;
            b >>= 1;
        }
        return ans;
    }
};
#endif // CLASS_MODINT
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
// Every computation below works modulo the prime 998244353 = 119 * 2^23 + 1,
// whose multiplicative group contains 2^23-th roots of unity (needed by the NTT).
using mint = modint<998244353>;
// Root tables indexed by transform level. main() fills entries 0..23 with
// fwdpow[i] = 3^(998244353 / 2^i) and revpow[i] = fwdpow[i]^(-1); the other
// entries remain default-constructed (zero) and must not be used.
mint fwdpow[32];
mint revpow[32];
// Number-theoretic transform of X over mint, computed in place.
//
// X.size() must be a power of two no larger than 2^23: pass `level` reads
// fwdpow[level] / revpow[level], and main() only initialises entries 0..23.
// With inv == false the forward roots fwdpow[] are used; with inv == true the
// inverse roots revpow[] are used and the result is NOT divided by X.size() --
// the caller must scale by 1/size itself (convolve() does this).
//
// Each pass writes its butterflies out of place into the scratch buffer Y, so
// no separate bit-reversal permutation step is performed.
void fourier_transform(vector<mint>& X, bool inv) {
    const int N = X.size();
    const mint* roots = inv ? revpow : fwdpow;
    vector<mint> Y(N);
    int level = 0;
    for (int half = N >> 1; half >= 1; half >>= 1) {
        ++level;
        // Root for this pass; mul walks through its powers across the j-blocks.
        const mint p = roots[level];
        mint mul = 1;
        for (int j = 0; 2 * j < N; j += half) {
            for (int k = 0; k < half; k++) {
                const mint g = mul * X[2 * j + k + half];
                Y[j + k] = X[2 * j + k] + g;
                Y[j + k + N / 2] = X[2 * j + k] - g;
            }
            mul *= p;
        }
        // The loops above overwrite every slot of Y (j + k covers [0, N/2) and
        // j + k + N/2 covers [N/2, N)), so an O(1) swap replaces the former
        // full O(N) copy; Y's stale contents are fully rewritten next pass.
        X.swap(Y);
    }
}
// Product of the polynomials A and B with coefficients in mint:
// result[k] = sum over i + j == k of A[i] * B[j], of length |A| + |B| - 1.
// Computed with three NTTs on a power-of-two length >= max(2, |A| + |B| - 1);
// that length must not exceed 2^23 (see fourier_transform).
// If the result length would be <= 0 (e.g. both inputs empty) an empty
// vector is returned instead of attempting a negative resize.
vector<mint> convolve(vector<mint> A, vector<mint> B) {
    // Add as size_t first, then subtract in int so that two empty inputs give -1
    // rather than a wrapped-around huge value.
    const int L = static_cast<int>(A.size() + B.size()) - 1;
    if (L <= 0) {
        return {};
    }
    int sz = 2;
    while (sz < L) {
        sz *= 2;
    }
    A.resize(sz);
    fourier_transform(A, false);
    B.resize(sz);
    fourier_transform(B, false);
    for (int i = 0; i < sz; ++i) {
        A[i] *= B[i];
    }
    fourier_transform(A, true);
    A.resize(L);
    // The inverse transform is unnormalised; divide by the transform length.
    const mint inv = mint(sz).inv();
    for (mint& a : A) {
        a *= inv;
    }
    return A;
}
// Returns the first L coefficients of the power-series inverse of
// C(x) = C[0] + C[1] x + ... + C[N-1] x^(N-1), i.e. the polynomial P(x) of
// degree < L with C(x) * P(x) == 1 (mod x^L).
//
// Precondition: C is non-empty and C[0] != 0 (mod p). Earlier versions
// required C[0] == 1; that case behaves exactly as before.
// Returns an empty vector when L <= 0.
//
// Newton iteration: if A * C == 1 (mod x^m), then A * (2 - A * C) * C == 1
// (mod x^(2m)). Cost is O(L log L) plus the NTT size limit of convolve().
vector<mint> polynomial_inverse(vector<mint> C, int L) {
    if (L <= 0) {
        return {};
    }
    const int N = C.size();
    // Invariant at the top of each iteration: A * C == 1 (mod x^(2^level)) and
    // A.size() == 2 << level. Entries beyond 2^level are harmless extra terms;
    // they keep P (below) at least 2 << level long even when C is very short.
    vector<mint> A = { C[0].inv(), mint(0) };
    int level = 0;
    while ((1 << level) < L) {
        // Only coefficients below x^(2^(level+1)) of A * C are needed, so C can be truncated.
        const int CS = min(2 << level, N);
        vector<mint> P = convolve(A, vector<mint>(C.begin(), C.begin() + CS));
        // Q = (2 - A * C) mod x^(2^(level+1)). By the invariant P[0] == 1 and
        // P[1 .. 2^level) == 0, so only Q[0] and the upper half need filling.
        vector<mint> Q(2 << level);
        Q[0] = 1;
        for (int j = (1 << level); j < (2 << level); ++j) {
            Q[j] = mint(0) - P[j];
        }
        A = convolve(A, Q);
        A.resize(4 << level);
        ++level;
    }
    A.resize(L);
    return A;
}
// Adjacency-list entry: destination vertex `to` together with a mint weight.
// A default-constructed edge has to == -1 and weight == 0.
class edge {
public:
    int to = -1;
    mint weight = mint();
    edge() = default;
    edge(int to_, const mint& weight_) : to(to_), weight(weight_) {}
};
int main() {
// step #0. prepare for NTT
for (int i = 0; i <= 23; ++i) {
fwdpow[i] = mint(3).pow((mint::get_mod() / (1 << i)));
revpow[i] = fwdpow[i].inv();
}
// step #1. read input (without queries) & make graph
int N;
cin >> N;
N += 1;
vector<int> P(N, -1);
for (int i = 1; i < N; i++) {
cin >> P[i];
}
vector<vector<edge> > G(N);
for (int i = 1; i < N; i++) {
int x;
cin >> x;
G[P[i]].push_back(edge(i, mint(x)));
}
auto solve = [&](int K, int mark) {
// step #2. compute values used in dynamic programming
vector<int> depth(N);
depth[0] = 0;
vector<mint> prob(N);
prob[0] = 1;
vector<bool> flag(N, false);
flag[mark] = true;
for (int i = 0; i < N; i++) {
if (!G[i].empty()) {
mint allmul = 0;
for (edge e : G[i]) {
allmul += e.weight;
}
allmul = prob[i] * allmul.inv();
for (edge e : G[i]) {
depth[e.to] = depth[i] + 1;
prob[e.to] = e.weight * allmul;
if (flag[i]) {
flag[e.to] = true;
}
}
}
}
// step #3. define polynomials
vector<mint> v1(N + 1), v2(N + 1), v3(N), v4(N);
v1[0] = 1;
for (int i = 0; i < N; i++) {
if (G[i].empty()) {
v1[depth[i] + 1] -= prob[i];
if (flag[i]) {
v2[depth[i] + 1] += prob[i];
}
}
}
v3[N - 1] = mint(0) - v1[N];
for (int i = N - 2; i >= 0; i--) {
v3[i] = v3[i + 1] - v1[i + 1];
}
v4[N - 1] = v2[N];
for (int i = N - 2; i >= depth[mark]; i--) {
v4[i] = v4[i + 1] + v2[i + 1];
}
// step #4. calculate v1 * v4 + v2 * v3 and v1^2
vector<mint> v5 = convolve(v1, v4);
vector<mint> v6 = convolve(v2, v3);
vector<mint> v7(2 * N);
for (int i = 0; i < 2 * N; i++) {
v7[i] = v5[i] + v6[i];
}
vector<mint> v8 = convolve(v1, v1);
vector<mint> v9 = polynomial_inverse(v8, 2 * N + 1);
// step #5. calculate coefficient of x^K in v7 / v8
vector<int> track = { K };
while (track.back() >= 2 * N) {
track.push_back(track.back() / 2);
}
reverse(track.begin(), track.end());
reverse(v7.begin(), v7.end());
reverse(v8.begin(), v8.end());
auto get_mod = [&](const vector<mint>& p) {
vector<mint> p1(p.begin() + 2 * N, p.end());
p1.resize(2 * N);
reverse(p1.begin(), p1.end());
vector<mint> p2 = convolve(p1, v9);
p2.resize(2 * N);
reverse(p2.begin(), p2.end());
vector<mint> p3 = convolve(p2, v8);
vector<mint> res(2 * N);
for (int j = 0; j < 2 * N; j++) {
res[j] = p[j] - p3[j];
}
return res;
};
vector<mint> poly(2 * N);
poly[track[0]] = 1;
for (int i = 1; i < int(track.size()); i++) {
poly = convolve(poly, poly);
if (track[i] % 2 == 1) {
poly.insert(poly.begin(), mint(0));
}
poly = get_mod(poly);
}
poly = convolve(poly, v7);
poly = get_mod(poly);
// step #6. calculate answer
mint answer = poly[2 * N - 1];
if (mark == 0) {
answer -= 1;
}
return answer;
};
// step #7. process queries
int Q;
cin >> Q;
for (int i = 0; i < Q; i++) {
int a, k;
cin >> a >> k;
mint answer = solve(k, a);
cout << answer.get() << endl;
}
return 0;
}
square1001