結果
| 問題 | No.1762 🐙🐄🌲 |
|---|---|
| コンテスト | |
| ユーザー | vjudge1 |
| 提出日時 | 2025-10-03 02:26:39 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | WA |
| 実行時間 | - |
| コード長 | 8,890 bytes |
| コンパイル時間 | 1,792 ms |
| コンパイル使用メモリ | 125,912 KB |
| 実行使用メモリ | 20,908 KB |
| 最終ジャッジ日時 | 2025-10-03 02:26:53 |
| 合計ジャッジ時間 | 6,447 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 WA * 2 |
| other | AC * 20 WA * 27 |
ソースコード
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;
// Prime modulus for all arithmetic: 998244353 = 119 * 2^23 + 1, so
// power-of-two NTT lengths up to 2^23 have roots of unity modulo MOD.
constexpr int MOD = 998244353;
// Primitive root modulo MOD; ntt() derives its roots of unity from it.
constexpr int G = 3;
// Modular exponentiation: returns base^exp mod MOD in [0, MOD) using
// binary (square-and-multiply) exponentiation.
// base may be any long long; negative values are normalized first because
// C++ '%' keeps the dividend's sign. For exp <= 0 the result is 1.
long long power(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    if (base < 0) base += MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}
// Modular inverse via Fermat's little theorem (MOD is prime):
// n^(MOD-2) * n == 1 (mod MOD) whenever n is not a multiple of MOD.
// NOTE: for n == 0 (mod MOD) this returns 0, which is not an inverse.
long long modInverse(long long n) {
    const long long fermat_exponent = MOD - 2;
    return power(n, fermat_exponent);
}
// Factorial table: fact[i] == i! mod MOD once precompute_factorials has run.
std::vector<long long> fact{};
// Inverse-factorial table: invFact[i] == (i!)^{-1} mod MOD.
std::vector<long long> invFact{};
// Fills fact[0..n] with i! mod MOD and invFact[0..n] with the inverses.
// Only one modular inverse is taken, (n!)^{-1}. The remaining entries follow
// from (k-1)!^{-1} = k!^{-1} * k. Assumes n >= 0.
void precompute_factorials(int n) {
    fact.resize(n + 1);
    invFact.resize(n + 1);
    fact[0] = 1;
    for (int k = 1; k <= n; ++k) fact[k] = fact[k - 1] * k % MOD;
    invFact[n] = modInverse(fact[n]);
    for (int k = n; k >= 1; --k) invFact[k - 1] = invFact[k] * k % MOD;
}
// Binomial coefficient C(n, r) mod MOD from the factorial tables.
// Returns 0 when r lies outside [0, n]; otherwise n must be < fact.size().
long long nCr(int n, int r) {
    const bool out_of_range = r < 0 || r > n;
    if (out_of_range) return 0;
    const long long numerator = fact[n];
    return numerator * invFact[r] % MOD * invFact[n - r] % MOD;
}
// --- NTT (Number Theoretic Transform) implementation ---
// Bit-reversal permutation: swaps a[i] with a[rev(i)], where rev(i) reverses
// the log2(n) low bits of i. Intended for power-of-two n; the callers in this
// file round their transform lengths up to powers of two.
void bit_reverse(vector<long long>& a, int n) {
    // rev is i's bit-reversed counterpart, maintained by "incrementing from the top".
    for (int i = 1, rev = 0; i < n; ++i) {
        int mask = n >> 1;
        for (; rev & mask; mask >>= 1) rev ^= mask;
        rev ^= mask;
        if (i < rev) swap(a[i], a[rev]);
    }
}
// In-place iterative NTT over Z/MOD on a vector whose length n is a power of two.
// invert == false: forward transform (evaluation at powers of an n-th root of unity).
// invert == true:  inverse transform, including the final scaling by n^{-1}.
void ntt(vector<long long>& a, bool invert) {
    const int n = a.size();
    bit_reverse(a, n);
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len / 2;
        // Primitive len-th root of unity (inverted for the inverse transform).
        const long long root = power(G, (MOD - 1) / len);
        const long long step = invert ? modInverse(root) : root;
        for (int start = 0; start < n; start += len) {
            long long w = 1;
            for (int k = 0; k < half; ++k) {
                const long long lo = a[start + k];
                const long long hi = a[start + k + half] * w % MOD;
                a[start + k] = (lo + hi) % MOD;
                a[start + k + half] = (lo - hi + MOD) % MOD;
                w = w * step % MOD;
            }
        }
    }
    if (!invert) return;
    const long long inv_n = modInverse(n);
    for (auto& x : a) x = x * inv_n % MOD;
}
// Polynomial multiplication modulo MOD via NTT.
// Returns the full product, of size a.size() + b.size() - 1. If either operand
// is empty, returns an empty vector; otherwise the length computation would
// be -1 and the final resize would fail.
vector<long long> poly_mul(vector<long long> a, vector<long long> b) {
    if (a.empty() || b.empty()) return {};
    const int result_len = static_cast<int>(a.size() + b.size() - 1);
    int n = 1;
    while (n < result_len) n <<= 1;  // NTT length: power of two covering the product
    a.resize(n, 0);
    b.resize(n, 0);
    ntt(a, false);
    ntt(b, false);
    for (int i = 0; i < n; i++) a[i] = a[i] * b[i] % MOD;  // pointwise product
    ntt(a, true);
    a.resize(result_len);
    return a;
}
// Polynomial inverse: returns B (n coefficients) with A * B == 1 (mod x^n).
// Uses Newton iteration from B0 = A^{-1} mod x^ceil(n/2):
//   B = B0 * (2 - A * B0) mod x^n,
// evaluated pointwise in the NTT domain.
// Requires a[0] != 0 (mod MOD). Returns {} for n <= 0 or an empty `a`.
// The n <= 0 guard matters: with n == 0, (n + 1) / 2 == 0 and the recursion
// would never reach the n == 1 base case.
vector<long long> poly_inv(const vector<long long>& a, int n) {
    if (n <= 0 || a.empty()) return {};
    if (n == 1) return {modInverse(a[0])};
    const vector<long long> b0 = poly_inv(a, (n + 1) / 2);
    // A * B0^2 has degree < 2n, so an NTT length >= 2n avoids wrap-around.
    int m = 1;
    while (m < 2 * n) m <<= 1;
    // A mod x^n, zero-padded to the transform length.
    vector<long long> A(a.begin(), a.begin() + min(static_cast<int>(a.size()), n));
    A.resize(m, 0);
    vector<long long> B(b0);
    B.resize(m, 0);
    ntt(A, false);
    ntt(B, false);
    for (int i = 0; i < m; i++) {
        const long long ab = A[i] * B[i] % MOD;
        B[i] = B[i] * ((2 - ab + MOD) % MOD) % MOD;  // B0 * (2 - A * B0)
    }
    ntt(B, true);
    B.resize(n);
    return B;
}
// Formal derivative: returns A'(x) mod MOD. A polynomial with at most one
// coefficient yields an empty vector.
vector<long long> poly_der(const vector<long long>& a) {
    vector<long long> res;
    if (a.size() <= 1) return res;
    res.reserve(a.size() - 1);
    for (size_t k = 1; k < a.size(); ++k) {
        res.push_back(a[k] * static_cast<long long>(k) % MOD);
    }
    return res;
}
// Formal integral with zero constant term: returns C with C' = A and C(0) = 0,
// of size a.size() + 1. An empty input yields {0}.
vector<long long> poly_int(const vector<long long>& a) {
    vector<long long> res(a.size() + 1, 0);
    for (size_t k = 0; k < a.size(); ++k) {
        const long long inv_k1 = modInverse(static_cast<long long>(k + 1));
        res[k + 1] = a[k] * inv_k1 % MOD;
    }
    return res;
}
// Polynomial logarithm: returns ln(A) mod x^n (exactly n coefficients),
// computed as the integral of A' * A^{-1}.
// Requires a[0] == 1. Returns {} if n <= 0, `a` is empty, or a[0] != 1.
vector<long long> poly_ln(const vector<long long>& a, int n) {
    if (n <= 0 || a.empty() || a[0] != 1) return {};
    const vector<long long> a_inv = poly_inv(a, n);  // A^{-1} mod x^n
    // Only A' mod x^(n-1) can influence (A' * A^{-1}) mod x^(n-1), so drop the
    // higher terms and keep the transform short.
    vector<long long> a_der = poly_der(a);
    if (static_cast<int>(a_der.size()) > n - 1) a_der.resize(n - 1);
    // Transform length large enough for the full linear product.
    int size = 1;
    while (size < static_cast<int>(a_inv.size() + a_der.size())) size <<= 1;
    vector<long long> prod = a_inv;
    prod.resize(size, 0);
    a_der.resize(size, 0);
    ntt(prod, false);
    ntt(a_der, false);
    for (int i = 0; i < size; i++) prod[i] = prod[i] * a_der[i] % MOD;
    ntt(prod, true);
    prod.resize(n - 1);     // A' * A^{-1} mod x^(n-1)
    return poly_int(prod);  // integrating gives n coefficients
}
// ????? (Polynomial Exponentiation)
// ?? exp(A) mod x^n. ?? A[0] = 0
vector<long long> poly_exp(const vector<long long>& a, int n) {
if (n == 1) return {1};
// ??? H_0 = exp(A) mod x^(n/2)
vector<long long> h0 = poly_exp(a, (n + 1) / 2);
// L = ln(H_0) mod x^n
vector<long long> L = poly_ln(h0, n);
// R = A - L mod x^n
vector<long long> R(n);
for (int i = 0; i < n; i++) {
long long a_i = i < a.size() ? a[i] : 0;
long long l_i = i < L.size() ? L[i] : 0;
R[i] = (a_i - l_i + MOD) % MOD;
}
R[0] = (R[0] + 1) % MOD; // H = H_0 * (1 + A - ln(H_0)) mod x^n
// ????? H = H_0 * R mod x^n
int size = 1;
while (size < 2 * n) size <<= 1;
vector<long long> h0_ntt = h0;
h0_ntt.resize(size, 0);
vector<long long> R_ntt = R;
R_ntt.resize(size, 0);
ntt(h0_ntt, false);
ntt(R_ntt, false);
vector<long long> res(size);
for (int i = 0; i < size; i++) {
res[i] = (h0_ntt[i] * R_ntt[i]) % MOD;
}
ntt(res, true);
res.resize(n);
return res;
}
// Computes G(x)^Q mod x^n as exp(Q * ln(G)). Requires G[0] == 1.
// (The parameter G shadows the global primitive-root constant in this function.)
//
// Q is reduced modulo MOD before use. This avoids overflow in ln_G[i] * Q and
// handles negative Q. It is exact for n <= MOD: over F_p, G^p == G(x^p), which
// is 1 mod x^n when G[0] == 1.
// Returns {} for n <= 0 or when ln(G) is undefined (empty G or G[0] != 1),
// rather than silently returning exp(0) == 1.
vector<long long> poly_q_pow(const vector<long long>& G, long long Q, int n) {
    if (n <= 0) return {};
    const vector<long long> ln_G = poly_ln(G, n);
    if (ln_G.empty()) return {};
    long long q = Q % MOD;
    if (q < 0) q += MOD;
    // A = Q * ln(G) mod x^n; ln_G[0] == 0, so A[0] == 0 as poly_exp requires.
    vector<long long> A(n, 0);
    for (int i = 0; i < static_cast<int>(ln_G.size()) && i < n; i++) {
        A[i] = ln_G[i] * q % MOD;
    }
    return poly_exp(A, n);
}
void solve() {
int N;
long long P;
cin >> N >> P;
// ?? N ? 5e5??????
precompute_factorials(N);
// --- 1. ?????? ---
// N ???? N ? 1 (mod 4)
if (N % 4 != 1) {
cout << 0 << endl;
return;
}
// ?? N_U (????) ? N_T (?????)
long long N_U = (N - 1) / 4;
long long N_T = (3LL * N + 1) / 4;
// Q ???????
long long Q = N_T - P;
// K ? c_i = d_i - 1 ???
// K = sum(d_i) - Q = (N-1 - 8P) - Q
// K = N - 1 - 8P - (N_T - P) = N - 1 - 7P - N_T
long long K_val = N - 1 - 7 * P - N_T;
// P ??????P <= (N-5)/28
// K >= 0 ?????
if (K_val < 0) {
cout << 0 << endl;
return;
}
// Q ????Q >= 0
if (Q < 0) {
cout << 0 << endl;
return;
}
// c_i <= 6 ????K <= 6Q
if (K_val > 6 * Q) {
cout << 0 << endl;
return;
}
int K = (int)K_val;
// --- 2. ????? ---
long long ans = 1;
// ????: C(N, N_T) * C(N_T, P)
ans = (ans * nCr(N, N_T)) % MOD;
ans = (ans * nCr((int)N_T, (int)P)) % MOD;
// ???: (N-2)! / ((7!)^P * (3!)^N_U)
// (N-2)!
if (N >= 2) {
ans = (ans * fact[N - 2]) % MOD;
} else { // N=1, N=2 ????? N ? 1 (mod 4)
// N >= 5
}
// 7! ? 3! ???
long long inv7Fact = invFact[7];
long long inv3Fact = invFact[3];
// (7!)^{-P}
ans = (ans * power(inv7Fact, P)) % MOD;
// (3!)^{-N_U}
ans = (ans * power(inv3Fact, N_U)) % MOD;
// --- 3. ????? T_Q ??? ---
// G(x) = sum_{c=0}^6 x^c / c!
vector<long long> G(7);
for (int c = 0; c <= 6; c++) {
G[c] = invFact[c];
}
// Q = N_T - P
// T_Q = [x^K] G(x)^Q
// ???????? H(x) = G(x)^Q mod x^{K+1}
// ???????? n = K + 1
int n = K + 1;
vector<long long> H = poly_q_pow(G, Q, n);
// ???? T_Q = H[K]
long long T_Q = H[K];
// --- 4. ???? ---
ans = (ans * T_Q) % MOD;
cout << ans << endl;
}
// Program entry point: switches iostreams to unsynchronized mode with cin
// untied from cout for faster I/O, then solves the single test case.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    solve();
}
vjudge1