結果

問題 No.2305 [Cherry 5th Tune N] Until That Day...
ユーザー square1001
提出日時 2023-05-15 01:11:30
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 2,689 ms / 10,000 ms
コード長 6,065 bytes
コンパイル時間 1,618 ms
コンパイル使用メモリ 103,716 KB
実行使用メモリ 6,820 KB
最終ジャッジ日時 2024-12-15 19:22:37
合計ジャッジ時間 32,425 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
6,816 KB
testcase_01 AC 1 ms
6,820 KB
testcase_02 AC 5 ms
6,820 KB
testcase_03 AC 8 ms
6,820 KB
testcase_04 AC 19 ms
6,816 KB
testcase_05 AC 57 ms
6,816 KB
testcase_06 AC 204 ms
6,820 KB
testcase_07 AC 2,593 ms
6,816 KB
testcase_08 AC 2,626 ms
6,820 KB
testcase_09 AC 2,655 ms
6,820 KB
testcase_10 AC 2,641 ms
6,816 KB
testcase_11 AC 2,689 ms
6,820 KB
testcase_12 AC 2,584 ms
6,816 KB
testcase_13 AC 2,639 ms
6,820 KB
testcase_14 AC 2,680 ms
6,816 KB
testcase_15 AC 2,657 ms
6,820 KB
testcase_16 AC 2,488 ms
6,816 KB
testcase_17 AC 2 ms
6,820 KB
testcase_18 AC 2,670 ms
6,820 KB
testcase_19 AC 397 ms
6,816 KB
testcase_20 AC 407 ms
6,816 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#ifndef CLASS_MODINT
#define CLASS_MODINT

#include <cstdint>

/// Integer modulo the compile-time constant `mod`.
///
/// The stored value is always the canonical representative in [0, mod).
/// Addition and subtraction never form an intermediate value >= 2^32, so the
/// class is correct for every non-zero 32-bit modulus (not only those below
/// 2^31). inv() raises to the power mod - 2, so it yields a multiplicative
/// inverse only when `mod` is prime and the value is non-zero.
template <std::uint32_t mod>
class modint {
private:
	std::uint32_t n;  // invariant: 0 <= n < mod
public:
	/// Zero.
	modint() : n(0) {}
	/// Reduces any signed 64-bit integer (negative values included) into [0, mod).
	/// Deliberately implicit: callers write `modint x = 1`.
	modint(std::int64_t n_) {
		const std::int64_t m = static_cast<std::int64_t>(mod);
		std::int64_t r = n_ % m;  // remainder lies in (-mod, mod); no negation, so INT64_MIN is safe
		if (r < 0) {
			r += m;
		}
		n = static_cast<std::uint32_t>(r);
	}
	static constexpr std::uint32_t get_mod() { return mod; }
	/// Canonical representative in [0, mod).
	std::uint32_t get() const { return n; }
	bool operator==(const modint& m) const { return n == m.n; }
	bool operator!=(const modint& m) const { return n != m.n; }
	modint& operator+=(const modint& m) {
		// n + m.n >= mod  <=>  n >= mod - m.n; comparing this way avoids uint32 wrap-around.
		const std::uint32_t gap = mod - m.n;
		if (n >= gap) {
			n -= gap;
		} else {
			n += m.n;
		}
		return *this;
	}
	modint& operator-=(const modint& m) {
		if (n < m.n) {
			n += mod - m.n;  // result is below mod, so this cannot wrap
		} else {
			n -= m.n;
		}
		return *this;
	}
	modint& operator*=(const modint& m) { n = std::uint64_t(n) * m.n % mod; return *this; }
	modint operator+(const modint& m) const { return modint(*this) += m; }
	modint operator-(const modint& m) const { return modint(*this) -= m; }
	modint operator*(const modint& m) const { return modint(*this) *= m; }
	/// Inverse via Fermat's little theorem (requires prime `mod`); returns 0 for 0.
	modint inv() const { return (*this).pow(mod - 2); }
	/// Binary exponentiation: returns (*this)^b using O(log b) multiplications.
	modint pow(std::uint64_t b) const {
		modint ans = 1, m = modint(*this);
		while (b) {
			if (b & 1) ans *= m;
			m *= m;
			b >>= 1;
		}
		return ans;
	}
};

#endif // CLASS_MODINT

#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
// Scalar type used by every polynomial routine below: integers modulo the prime 998244353.
using mint = modint<998244353>;

// Per-level multipliers read by fourier_transform(): fwdpow[level] for the
// forward transform, revpow[level] for the inverse one. main() fills entries
// 0..23 with fwdpow[i] = 3^(mod / 2^i) and revpow[i] = fwdpow[i].inv(); the
// remaining entries keep their default value 0.
mint fwdpow[32], revpow[32];

// In-place number-theoretic transform of X over mint.
//
// X    : coefficient vector, overwritten with its transform. Its length is
//        expected to be a power of two (convolve() guarantees this); other
//        lengths can index out of range.
// inv  : false uses the fwdpow[] multipliers, true uses revpow[]. The inverse
//        pass does NOT divide by the length; the caller does that.
//
// fwdpow[]/revpow[] must have been filled (see main) for every level used,
// i.e. up to log2(X.size()).
void fourier_transform(vector<mint>& X, bool inv) {
	const int N = X.size();
	vector<mint> Y(N);  // scratch buffer, ping-ponged with X
	int level = 0;
	for (int i = N >> 1; i >= 1; i >>= 1) {
		++level;
		const mint p = inv ? revpow[level] : fwdpow[level];
		mint mul = 1;  // p^(j / i) for the current group of butterflies
		for (int j = 0; 2 * j < N; j += i) {
			for (int k = 0; k < i; k++) {
				const mint lo = X[2 * j + k];
				const mint g = mul * X[2 * j + k + i];
				Y[j + k] = lo + g;
				Y[j + k + N / 2] = lo - g;
			}
			mul *= p;
		}
		// Every entry of Y was just rewritten, so exchanging the buffers is
		// equivalent to the copy `X = Y` but costs O(1) instead of O(N).
		X.swap(Y);
	}
}

// Returns the product of the polynomials A and B (coefficients in increasing
// degree), computed with fourier_transform().
//
// The result has A.size() + B.size() - 1 coefficients. If either operand is
// empty the product is the empty polynomial and an empty vector is returned.
// The operands are taken by value because they are zero-padded and
// transformed in place.
vector<mint> convolve(vector<mint> A, vector<mint> B) {
	if (A.empty() || B.empty()) {
		// Without this guard A.size() + B.size() - 1 would wrap around and
		// the final resize would request an enormous length.
		return {};
	}
	const int L = A.size() + B.size() - 1;
	// Transform length: the smallest power of two that is >= max(2, L).
	int sz = 2;
	while (sz < L) {
		sz *= 2;
	}
	A.resize(sz);
	fourier_transform(A, false);
	B.resize(sz);
	fourier_transform(B, false);
	for (int i = 0; i < sz; ++i) {
		A[i] *= B[i];
	}
	fourier_transform(A, true);
	A.resize(L);
	// The inverse transform leaves a factor of sz; divide it out on the kept part only.
	const mint inv = mint(sz).inv();
	for (int i = 0; i < L; ++i) {
		A[i] *= inv;
	}
	return A;
}

vector<mint> polynomial_inverse(vector<mint> C, int L) {
	// Finds the polynomial P(x) of degree at most L-1 that satisfies
	//   (C[0] + C[1] * x + ... + C[N-1] * x^(N-1)) * P(x) == 1 (mod x^L).
	// (The original author's note gives the running time as O(N log N).)
	// Constraint: C[0] must be equal to 1.
	const int N = C.size();
	// Invariant at the top of each round: A * C == 1 (mod x^(1 << level)).
	vector<mint> A{ mint(1), mint(0) };
	for (int level = 0; (1 << level) < L; ++level) {
		const int half = 1 << level;
		const int full = 2 << level;
		// Only the first `full` coefficients of C influence this round.
		const int CS = min(full, N);
		const vector<mint> prefix(C.begin(), C.begin() + CS);
		const vector<mint> P = convolve(A, prefix);
		// Q = 1 - (coefficients half .. full-1 of A * C).
		vector<mint> Q(full);
		Q[0] = 1;
		for (int j = half; j < full; ++j) {
			Q[j] = mint(0) - P[j];
		}
		A = convolve(A, Q);
		A.resize(2 * full);
	}
	A.resize(L);
	return A;
}

// Directed, weighted edge of the input tree: destination vertex plus weight.
class edge {
public:
	int to;       // index of the vertex this edge leads to (-1 when default-constructed)
	mint weight;  // edge weight reduced modulo the mint modulus

	// Placeholder edge: no destination, zero weight.
	edge() : to(-1), weight() {}

	// Edge leading to vertex `to_` with weight `weight_`.
	edge(int to_, const mint& weight_) : to(to_), weight(weight_) {}
};

int main() {
	// step #0. prepare for NTT
	for (int i = 0; i <= 23; ++i) {
		fwdpow[i] = mint(3).pow((mint::get_mod() / (1 << i)));
		revpow[i] = fwdpow[i].inv();
	}

	// step #1. read input (without queries) & make graph
	int N;
	cin >> N;
	N += 1;
	vector<int> P(N, -1);
	for (int i = 1; i < N; i++) {
		cin >> P[i];
	}
	vector<vector<edge> > G(N);
	for (int i = 1; i < N; i++) {
		int x;
		cin >> x;
		G[P[i]].push_back(edge(i, mint(x)));
	}

	auto solve = [&](int K, int mark) {
		// step #2. compute values used in dynamic programming
		vector<int> depth(N);
		depth[0] = 0;
		vector<mint> prob(N);
		prob[0] = 1;
		vector<bool> flag(N, false);
		flag[mark] = true;
		for (int i = 0; i < N; i++) {
			if (!G[i].empty()) {
				mint allmul = 0;
				for (edge e : G[i]) {
					allmul += e.weight;
				}
				allmul = prob[i] * allmul.inv();
				for (edge e : G[i]) {
					depth[e.to] = depth[i] + 1;
					prob[e.to] = e.weight * allmul;
					if (flag[i]) {
						flag[e.to] = true;
					}
				}
			}
		}

		// step #3. define polynomials
		vector<mint> v1(N + 1), v2(N + 1), v3(N), v4(N);
		v1[0] = 1;
		for (int i = 0; i < N; i++) {
			if (G[i].empty()) {
				v1[depth[i] + 1] -= prob[i];
				if (flag[i]) {
					v2[depth[i] + 1] += prob[i];
				}
			}
		}
		v3[N - 1] = mint(0) - v1[N];
		for (int i = N - 2; i >= 0; i--) {
			v3[i] = v3[i + 1] - v1[i + 1];
		}
		v4[N - 1] = v2[N];
		for (int i = N - 2; i >= depth[mark]; i--) {
			v4[i] = v4[i + 1] + v2[i + 1];
		}

		// step #4. calculate v1 * v4 + v2 * v3 and v1^2
		vector<mint> v5 = convolve(v1, v4);
		vector<mint> v6 = convolve(v2, v3);
		vector<mint> v7(2 * N);
		for (int i = 0; i < 2 * N; i++) {
			v7[i] = v5[i] + v6[i];
		}
		vector<mint> v8 = convolve(v1, v1);
		vector<mint> v9 = polynomial_inverse(v8, 2 * N + 1);

		// step #5. calculate coefficient of x^K in v7 / v8
		vector<int> track = { K };
		while (track.back() >= 2 * N) {
			track.push_back(track.back() / 2);
		}
		reverse(track.begin(), track.end());
		reverse(v7.begin(), v7.end());
		reverse(v8.begin(), v8.end());
		auto get_mod = [&](const vector<mint>& p) {
			vector<mint> p1(p.begin() + 2 * N, p.end());
			p1.resize(2 * N);
			reverse(p1.begin(), p1.end());
			vector<mint> p2 = convolve(p1, v9);
			p2.resize(2 * N);
			reverse(p2.begin(), p2.end());
			vector<mint> p3 = convolve(p2, v8);
			vector<mint> res(2 * N);
			for (int j = 0; j < 2 * N; j++) {
				res[j] = p[j] - p3[j];
			}
			return res;
		};
		vector<mint> poly(2 * N);
		poly[track[0]] = 1;
		for (int i = 1; i < int(track.size()); i++) {
			poly = convolve(poly, poly);
			if (track[i] % 2 == 1) {
				poly.insert(poly.begin(), mint(0));
			}
			poly = get_mod(poly);
		}
		poly = convolve(poly, v7);
		poly = get_mod(poly);

		// step #6. calculate answer
		mint answer = poly[2 * N - 1];
		if (mark == 0) {
			answer -= 1;
		}

		return answer;
	};

	// step #7. process queries
	int Q;
	cin >> Q;
	for (int i = 0; i < Q; i++) {
		int a, k;
		cin >> a >> k;
		mint answer = solve(k, a);
		cout << answer.get() << endl;
	}
	
	return 0;
}
0