結果
問題 | No.1762 🐙🐄🌲 |
ユーザー |
![]() |
提出日時 | 2025-05-14 13:06:17 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 1,367 ms / 4,000 ms |
コード長 | 9,720 bytes |
コンパイル時間 | 1,224 ms |
コンパイル使用メモリ | 82,516 KB |
実行使用メモリ | 28,392 KB |
最終ジャッジ日時 | 2025-05-14 13:07:58 |
合計ジャッジ時間 | 32,021 ms |
ジャッジサーバーID (参考情報) | judge2 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 47 |
ソースコード
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm> // for std::min, std::for_each
#include <utility>   // for std::swap

// --- Modular Arithmetic ---
const int MOD = 998244353;

// Computes base^exp mod MOD by binary exponentiation.
// A negative `base` is first reduced into [0, MOD) (C++ `%` keeps the sign of
// the dividend, so without this the result could be negative).
// A non-positive `exp` returns 1.
long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    if (base < 0) base += MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

// Modular inverse via Fermat's little theorem (MOD is prime): n^(MOD-2).
// Requires n % MOD != 0; for a multiple of MOD this returns 0, not an inverse.
long long modInverse(long long n) {
    return power(n, MOD - 2);
}

// --- Number Theoretic Transform (NTT) ---
// NTT_SIZE = 2^18 = 262144. The caller must ensure every product it needs fits
// in one cyclic transform of this length.
const int NTT_LOG = 18;
const int NTT_SIZE = 1 << NTT_LOG;
const int G = 3; // Primitive root for 998244353

// MOD - 1 = 119 * 2^23, so 2^k-th roots of unity exist only for k <= 23.
static_assert(NTT_LOG <= 23, "998244353 has no 2^NTT_LOG-th roots of unity for NTT_LOG > 23");

long long W[NTT_SIZE], W_inv[NTT_SIZE]; // W[i] = w^i, W_inv[i] = w^-i for a primitive NTT_SIZE-th root w
int rev[NTT_SIZE];                      // rev[i] = i with its NTT_LOG low bits reversed

// Fills W, W_inv and rev. Must run once before any call to ntt().
void precompute_ntt() {
    const long long root = power(G, (MOD - 1) / NTT_SIZE); // primitive NTT_SIZE-th root of unity
    const long long root_inv = modInverse(root);
    W[0] = W_inv[0] = 1;
    for (int i = 1; i < NTT_SIZE; ++i) {
        W[i] = (W[i - 1] * root) % MOD;
        W_inv[i] = (W_inv[i - 1] * root_inv) % MOD;
    }
    // Reversing i's bits equals reversing i >> 1, shifting right once, and
    // moving i's lowest bit to the top position; this builds the table in O(N).
    rev[0] = 0;
    for (int i = 1; i < NTT_SIZE; ++i) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (NTT_LOG - 1));
    }
}
void ntt(std::vector<long long>& a, bool invert) { int n = NTT_SIZE; // Use fixed NTT size if (a.size() < n) a.resize(n, 0); // Pad with zeros if smaller than NTT_SIZE // Apply bit reversal permutation for (int i = 0; i < n; ++i) { if (i < rev[i]) { std::swap(a[i], a[rev[i]]); } } long long* roots = invert ? W_inv : W; // Choose roots based on direction // Butterfly operations for (int len = 2; len <= n; len <<= 1) { // Iterate through lengths 2, 4, ..., n int step = NTT_SIZE / len; // Step size to pick roots int half_len = len >> 1; for (int i = 0; i < n; i += len) { // Iterate through blocks for (int j = 0; j < half_len; j++) { // Iterate within block halves long long w = roots[j * step]; // Get appropriate root of unity long long u = a[i + j]; long long v = (a[i + j + half_len] * w) % MOD; a[i + j] = (u + v) % MOD; // Combine results a[i + j + half_len] = (u - v + MOD) % MOD; // Ensure positive result } } } // Scale by 1/n if inverse transform if (invert) { long long n_inv = modInverse(n); for (int i=0; i<n; ++i) { a[i] = (a[i] * n_inv) % MOD; } } } // Multiplies two polynomials `a` and `b` using NTT. // Returns the resulting polynomial truncated to degree `result_len - 1`. // `result_len` is the number of coefficients (degree + 1). 
std::vector<long long> multiply(const std::vector<long long>& a, const std::vector<long long>& b, int result_len) { // Copy input polynomials and resize to NTT_SIZE, padding with zeros std::vector<long long> fa(a); fa.resize(NTT_SIZE, 0); std::vector<long long> fb(b); fb.resize(NTT_SIZE, 0); ntt(fa, false); // Forward NTT on fa ntt(fb, false); // Forward NTT on fb // Pointwise multiplication in frequency domain std::vector<long long> result(NTT_SIZE); for (int i = 0; i < NTT_SIZE; i++) result[i] = (fa[i] * fb[i]) % MOD; ntt(result, true); // Inverse NTT to get coefficients result.resize(result_len); // Truncate result to the required length return result; } // Computes polynomial exponentiation: base^exp mod x^result_len // `result_len` is the maximum number of coefficients needed. std::vector<long long> poly_pow(std::vector<long long> base, long long exp, int result_len) { std::vector<long long> res(result_len); // Result polynomial initialized to 0 if (result_len > 0) res[0] = 1; else return {}; // If result_len is 0, return empty. Otherwise, start with polynomial 1. // Ensure base polynomial does not exceed required length if (base.size() > result_len) base.resize(result_len); // Standard binary exponentiation (exponentiation by squaring) while (exp > 0) { if (exp % 2 == 1) { // If exponent is odd res = multiply(res, base, result_len); // Multiply result by base } // Square base for next iteration, only if needed if (exp > 1) { base = multiply(base, base, result_len); } exp /= 2; // Divide exponent by 2 } res.resize(result_len); // Ensure final result has the correct size return res; } // --- Factorials --- const int MAX_N_Factorial = 500001; // Maximum N is 5e5, need factorials up to N long long fact[MAX_N_Factorial]; // fact[i] = i! 
mod MOD long long invFact[MAX_N_Factorial]; // invFact[i] = (i!)^-1 mod MOD // Precompute factorials and their modular inverses up to MAX_N_Factorial - 1 void precompute_factorials_optimized() { fact[0] = 1; for (int i = 1; i < MAX_N_Factorial; ++i) { fact[i] = (fact[i - 1] * i) % MOD; } // Compute inverse of N! using Fermat's Little Theorem invFact[MAX_N_Factorial - 1] = modInverse(fact[MAX_N_Factorial - 1]); // Compute other inverse factorials iteratively: invFact[i] = invFact[i+1] * (i+1) for (int i = MAX_N_Factorial - 2; i >= 0; --i) { invFact[i] = (invFact[i + 1] * (i + 1)) % MOD; } } // --- Main Logic --- int main() { std::ios_base::sync_with_stdio(false); // Faster I/O std::cin.tie(NULL); int N; long long P; // P can be up to N, use long long std::cin >> N >> P; // Condition check: N must be at least 2 and N-1 must be divisible by 4 // If N=1 mod 4 doesn't hold, or N < 2, no such tree exists. // The smallest N satisfying N>=2 and N=1 mod 4 is N=5. if ((N - 1) % 4 != 0 || N < 2) { std::cout << 0 << std::endl; return 0; } // Calculate number of Octopus (N_O) and Cow (N_C) vertices long long N_O = (3LL * N + 1) / 4; long long N_C = (N - 1) / 4; // K is related to the sum of (degree-1) for Octopus vertices long long K = (N - 5) / 4; // Check if K is valid. For N=5, K=0. Minimal N is 5, so K >= 0. // K' is the target sum of (degree-1) for non-Perfect Octopus vertices long long K_prime = K - 7LL * P; // If K' is negative, it's impossible to satisfy degree sum constraints. if (K_prime < 0) { std::cout << 0 << std::endl; return 0; } // P must be non-negative and cannot exceed the number of Octopus vertices. if (P < 0 || P > N_O) { std::cout << 0 << std::endl; return 0; } // Precompute necessary values precompute_factorials_optimized(); precompute_ntt(); // Define the base polynomial f(x) = sum_{k=0..6} x^k / k! 
std::vector<long long> f(7); for (int k = 0; k <= 6; ++k) { f[k] = invFact[k]; } // Number of non-Perfect Octopus vertices long long N_O_minus_P = N_O - P; std::vector<long long> g; // Result of polynomial exponentiation int required_poly_len = K_prime + 1; // Need coefficient of x^K', so need poly up to degree K' // Check if required polynomial length exceeds NTT capability (sanity check) if (required_poly_len > NTT_SIZE) { std::cerr << "Error: Required polynomial length " << required_poly_len << " exceeds NTT size " << NTT_SIZE << "." << std::endl; return 1; // Should not happen based on constraints } // Compute g(x) = f(x)^(N_O - P) if (N_O_minus_P == 0) { // Special case: exponent is 0 g.resize(required_poly_len, 0); if (K_prime == 0) g[0] = 1; // f(x)^0 = 1. Coeff of x^0 is 1. All others 0. } else { g = poly_pow(f, N_O_minus_P, required_poly_len); } // Extract the coefficient C_{K'} = [x^{K'}] g(x) long long C_K_prime = 0; if (K_prime < g.size()) { // Check index boundary C_K_prime = g[K_prime]; } // If K_prime >= g.size(), means K_prime >= required_poly_len. This coefficient is implicitly 0. // Precompute modular inverses of constants 6 and 5040 (7!) long long inv6 = modInverse(6); long long inv5040 = modInverse(5040); // Calculate the final answer using the derived formula: // TotalCount = (N! * (N_O-1)!) / (N_C * 6^N_C * P! * (N_O-P)! * 5040^P) * C_{K'} long long ans = fact[N]; ans = (ans * fact[N_O - 1]) % MOD; // N_C = (N-1)/4. Since N>=5, N_C >= 1. So N_C is non-zero. ans = (ans * modInverse(N_C)) % MOD; ans = (ans * power(inv6, N_C)) % MOD; // (1/6)^N_C ans = (ans * invFact[P]) % MOD; // 1/P! // P <= N_O ensures N_O - P >= 0. Index is valid. ans = (ans * invFact[N_O - P]) % MOD; // 1/((N_O-P)!) ans = (ans * power(inv5040, P)) % MOD; // (1/5040)^P ans = (ans * C_K_prime) % MOD; // Multiply by the coefficient computed std::cout << ans << std::endl; return 0; }