結果
| 問題 | No.3394 Big Binom |
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2025-12-01 20:26:16 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.89.0) |
| 結果 |
AC
|
| 実行時間 | 258 ms / 2,000 ms |
| コード長 | 27,880 bytes |
| 記録 | |
| コンパイル時間 | 6,700 ms |
| コンパイル使用メモリ | 337,860 KB |
| 実行使用メモリ | 19,448 KB |
| 最終ジャッジ日時 | 2025-12-01 20:26:26 |
| 合計ジャッジ時間 | 9,400 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 21 |
ソースコード
#include <bits/stdc++.h>
using namespace std;
// Optional switch left by the author: uncomment to make every `int` a `long long`.
//#define int long long
///You can add it if you want
// Modulus used by `mint` (ModInt<md>) and by the convolution dispatch in operator*.
const int md = 998244353;
// Default-seeded Mersenne Twister; not referenced by the code visible in this file.
mt19937 rnd;
// Shorthand: `v.app(x)` expands to `v.push_back(x)`.
#define app push_back
// Shorthand for the begin/end iterator pair of a container.
#define all(x) (x).begin(),(x).end()
#ifdef LOCAL
// debug(a, b, ...): prints the stringified argument list, then ":", then every value,
// each followed by a space, and finally a newline (with flush) on cout.
#define debug(...) [](auto...a){ ((cout << a << ' '), ...) << endl;}(#__VA_ARGS__, ":", __VA_ARGS__)
// debugv(v): prints `<name of v> : {e0,e1,...}` for an indexable container v.
#define debugv(v) do {cout<< #v <<" : {"; for(int izxc=0;izxc<v.size();++izxc) {cout << v[izxc];if(izxc+1!=v.size()) cout << ","; }cout <<"}"<< endl;} while(0)
#else
// Without LOCAL both macros expand to nothing, so their arguments are never evaluated.
#define debug(...)
#define debugv(v)
#endif
// Shorthands for lower_bound / upper_bound over a whole container.
#define lob(a,x) lower_bound(all(a),x)
#define upb(a,x) upper_bound(all(a),x)
/// Number-theoretic transform modulo M together with a convolution built on top of it.
///
/// Template parameters:
///   M - modulus of all arithmetic.
///   K - transforms of length up to 2^K are supported (asserted in fft()).
///   G - must satisfy G^(2^(K-1)) == -1 (mod M), i.e. have multiplicative order 2^K;
///       the constructor asserts this.
template<int M, int K, int G>
struct Fft
{
    // Twiddle table. Apart from g[0] (the value 1) the author's legend lists the exponents
    // as fractions of a full turn:
    // 1, 1/4, 1/8, 3/8, 1/16, 5/16, 3/16, 7/16, ...
    int g[1 << (K - 1)];

    Fft() : g()
    {
        // The table is seeded at index 1 << (K - 2), which is only meaningful for K >= 2.
        static_assert(K >= 2, "Fft: K >= 2 must hold");
        g[0] = 1;
        g[1 << (K - 2)] = G;
        // Repeated squaring fills the power-of-two slots down to g[1].
        for (int l = 1 << (K - 2); l >= 2; l >>= 1)
        {
            g[l >> 1] = (g[l] * 1LL * g[l]) % M;
        }
        // g[1] must be a square root of -1, otherwise G does not have order 2^K.
        assert((g[1] * 1LL * g[1]) % M == M - 1);
        // Remaining slots are products of a power-of-two slot and a lower slot.
        for (int l = 2; l <= 1 << (K - 2); l <<= 1)
        {
            for (int i = 1; i < l; ++i)
            {
                g[l + i] = (g[l] * 1LL * g[i]) % M;
            }
        }
    }

    /// In-place transform of x. x.size() must be a power of two not exceeding 2^K and
    /// every entry must already lie in [0, M). convolution() uses this routine both as
    /// the forward transform and (after reversing x[1..n-1]) as the inverse one.
    void fft(vector<int> &x) const
    {
        const int n = x.size();
        assert(n <= 1 << K);
        // __builtin_ctz(0) is undefined and the butterflies assume a power-of-two length.
        assert(n > 0 && (n & (n - 1)) == 0);
        for (int h = __builtin_ctz(n); h--;)
        {
            const int l = (1 << h);
            for (int i = 0; i < n >> (h + 1); ++i)
            {
                for (int j = i << (h + 1); j < (((i << 1) + 1) << h); ++j)
                {
                    const int t = (g[i] * 1LL * x[j | l]) % M;
                    x[j | l] = x[j] - t;
                    if (x[j | l] < 0)
                        x[j | l] += M;
                    x[j] += t;
                    if (x[j] >= M)
                        x[j] -= M;
                }
            }
        }
        // Bit-reversal permutation; j is advanced as a reversed binary counter.
        for (int i = 0, j = 0; i < n; ++i)
        {
            if (i < j)
                std::swap(x[i], x[j]);
            for (int l = n; (l >>= 1) && !((j ^= l) & l);)
            {
            }
        }
    }

    /// Returns the product of the polynomials a and b with coefficients reduced into [0, M).
    /// Inputs may contain any int values (negative ones included). An empty operand
    /// yields an empty result; otherwise the result has a.size() + b.size() - 1 entries.
    vector<int> convolution(vector<int> a, vector<int> b) const
    {
        if (a.empty() || b.empty())
            return {};
        // After `%` the value lies in (-M, M), so only the negative case needs fixing.
        const auto normalize = [](vector<int> &v)
        {
            for (int &x: v)
            {
                x %= M;
                if (x < 0)
                    x += M;
            }
        };
        normalize(a);
        normalize(b);
        const int na = a.size(), nb = b.size();
        // n: smallest power of two >= na + nb - 1; invN tracks 1/n mod M by halving
        // (adding the odd modulus first whenever the current value is odd).
        int n, invN = 1;
        for (n = 1; n < na + nb - 1; n <<= 1)
            invN = ((invN & 1) ? (invN + M) : invN) >> 1;
        vector<int> x(n, 0), y(n, 0);
        std::copy(a.begin(), a.end(), x.begin());
        std::copy(b.begin(), b.end(), y.begin());
        fft(x);
        fft(y);
        for (int i = 0; i < n; ++i)
            x[i] = (((static_cast<long long>(x[i]) * y[i]) % M) * invN) % M;
        // Reversing everything but the first entry turns the forward transform into the inverse.
        std::reverse(x.begin() + 1, x.end());
        fft(x);
        x.resize(na + nb - 1);
        return x;
    }
};
// Global transform instance for modulus 998244353 with K = 21 (lengths up to 2^21);
// 31^4 is passed as the generator G that the Fft constructor validates.
Fft<998244353, 21, 31 * 31 * 31 * 31> muls;
/// Integer modulo the compile-time constant MOD. `value` is kept in [0, MOD) by every
/// constructor and operator defined here.
template<int32_t MOD>
struct ModInt
{
    int32_t value;

    ModInt() : value(0)
    {
    }

    /// Reduces v into [0, MOD); negative inputs are mapped to their non-negative residue.
    ModInt(long long v) : value(v % MOD)
    {
        if (value < 0)
            value += MOD;
    }

    ModInt(int32_t v) : value(v % MOD)
    {
        if (value < 0)
            value += MOD;
    }

    ModInt &operator+=(ModInt m)
    {
        // Adding (m.value - MOD) keeps the intermediate inside (-MOD, MOD), so it cannot
        // overflow int32_t even when MOD is close to 2^31.
        value += m.value - MOD;
        if (value < 0)
            value += MOD;
        return *this;
    }

    ModInt &operator-=(ModInt m)
    {
        value -= m.value;
        if (value < 0)
            value += MOD;
        return *this;
    }

    ModInt &operator*=(ModInt m)
    {
        value = (value * 1LL * m.value) % MOD;
        return *this;
    }

    /// Returns (*this)^exp by binary exponentiation. exp must be non-negative.
    ModInt power(long long exp) const
    {
        assert(exp >= 0);
        ModInt result(1);
        ModInt base = *this;
        for (; exp > 0; exp >>= 1)
        {
            if (exp & 1)
                result *= base;
            base *= base;
        }
        return result;
    }

    /// Division through Fermat's little theorem, so MOD is assumed to be prime.
    /// Dividing by zero silently produces 0.
    ModInt &operator/=(ModInt m) { return *this *= m.power(MOD - 2); }

    /// Reads an integer and stores its residue; a failed extraction stores the residue of
    /// whatever the stream left in the temporary (0 for non-numeric input).
    friend std::istream &operator>>(std::istream &is, ModInt &m)
    {
        long long raw = 0;
        is >> raw;
        m = ModInt(raw);
        return is;
    }

    friend std::ostream &operator<<(std::ostream &os, const ModInt &m)
    {
        os << m.value;
        return os;
    }

    explicit operator int32_t() const { return value; }
    explicit operator long long() const { return value; }
    static int32_t mod() { return MOD; }
};
// Binary arithmetic and comparison operators for ModInt. Each arithmetic operator comes
// in three flavours: (ModInt, ModInt), (other, ModInt) and (ModInt, other); the "other"
// operand is converted through ModInt's constructors.
template<int32_t MOD>
ModInt<MOD> operator+(ModInt<MOD> a, ModInt<MOD> b)
{
    a += b;
    return a;
}
template<int32_t MOD, typename L>
ModInt<MOD> operator+(L a, ModInt<MOD> b)
{
    ModInt<MOD> lhs(a);
    lhs += b;
    return lhs;
}
template<int32_t MOD, typename R>
ModInt<MOD> operator+(ModInt<MOD> a, R b)
{
    a += b;
    return a;
}
template<int32_t MOD>
ModInt<MOD> operator-(ModInt<MOD> a, ModInt<MOD> b)
{
    a -= b;
    return a;
}
template<int32_t MOD, typename L>
ModInt<MOD> operator-(L a, ModInt<MOD> b)
{
    ModInt<MOD> lhs(a);
    lhs -= b;
    return lhs;
}
template<int32_t MOD, typename R>
ModInt<MOD> operator-(ModInt<MOD> a, R b)
{
    a -= b;
    return a;
}
template<int32_t MOD>
ModInt<MOD> operator*(ModInt<MOD> a, ModInt<MOD> b)
{
    a *= b;
    return a;
}
template<int32_t MOD, typename L>
ModInt<MOD> operator*(L a, ModInt<MOD> b)
{
    ModInt<MOD> lhs(a);
    lhs *= b;
    return lhs;
}
template<int32_t MOD, typename R>
ModInt<MOD> operator*(ModInt<MOD> a, R b)
{
    a *= b;
    return a;
}
template<int32_t MOD>
ModInt<MOD> operator/(ModInt<MOD> a, ModInt<MOD> b)
{
    a /= b;
    return a;
}
template<int32_t MOD, typename L>
ModInt<MOD> operator/(L a, ModInt<MOD> b)
{
    ModInt<MOD> lhs(a);
    lhs /= b;
    return lhs;
}
template<int32_t MOD, typename R>
ModInt<MOD> operator/(ModInt<MOD> a, R b)
{
    a /= b;
    return a;
}
// Comparisons look at the stored residue only; a non-ModInt operand is compared as-is,
// without being reduced modulo MOD first.
template<int32_t MOD>
bool operator==(ModInt<MOD> a, ModInt<MOD> b)
{
    return a.value == b.value;
}
template<int32_t MOD, typename L>
bool operator==(L a, ModInt<MOD> b)
{
    return a == b.value;
}
template<int32_t MOD, typename R>
bool operator==(ModInt<MOD> a, R b)
{
    return a.value == b;
}
template<int32_t MOD>
bool operator!=(ModInt<MOD> a, ModInt<MOD> b)
{
    return a.value != b.value;
}
template<int32_t MOD, typename L>
bool operator!=(L a, ModInt<MOD> b)
{
    return a != b.value;
}
template<int32_t MOD, typename R>
bool operator!=(ModInt<MOD> a, R b)
{
    return a.value != b;
}
using mint = ModInt<md>;
/// Multiplicative inverse of x modulo md, obtained through ModInt division.
mint inv(mint x)
{
    mint one(1);
    one /= x;
    return one;
}
/// Extended Euclidean algorithm on 128-bit integers.
/// Returns g, the last non-zero remainder of the Euclidean chain started at (a, b), and
/// stores in x and y coefficients with a * x + b * y == g.
__int128 gcd(__int128 a, __int128 b, __int128 &x, __int128 &y)
{
    // (r0, r1) is the current remainder pair; (s0, t0) and (s1, t1) express r0 and r1
    // as integer combinations of the original a and b.
    __int128 r0 = a, r1 = b;
    __int128 s0 = 1, s1 = 0;
    __int128 t0 = 0, t1 = 1;
    while (r1 != 0)
    {
        const __int128 q = r0 / r1;
        const __int128 rNext = r0 % r1;
        const __int128 sNext = s0 - q * s1;
        const __int128 tNext = t0 - q * t1;
        r0 = r1;
        r1 = rNext;
        s0 = s1;
        s1 = sNext;
        t0 = t1;
        t1 = tNext;
    }
    x = s0;
    y = t0;
    return r0;
}
/// Inverse of r modulo m, derived from the Bezout coefficient of r returned by gcd().
/// The result is only a true inverse when r and m are coprime.
__int128 inv(__int128 r, __int128 m)
{
    __int128 coefR = 0;
    __int128 coefM = 0;
    gcd(r, m, coefR, coefM);
    // One addition of m lifts a negative coefficient before the final reduction.
    const __int128 lifted = coefR + m;
    return lifted % m;
}
/// One Chinese-remainder step: given a value known to be r modulo n, returns
/// r + t * n with t in [0, m) chosen so that the result is also c modulo m
/// (this relies on n being invertible modulo m).
__int128 crt(__int128 r, __int128 n, __int128 c, __int128 m)
{
    const __int128 gap = (c - r) % m + m; // strictly positive representative of c - r
    const __int128 steps = gap * inv(n, m) % m;
    return r + steps * n;
}
// Two further moduli with their own transform instances. operator* on vector<mint>
// combines results modulo 998244353, m2 and m3 through crt() when md != 998244353.
const int m2 = 167772161, m3 = 469762049;
// The third template argument is the generator G that the Fft constructor validates.
Fft<m2, 21, 147771621> muls2;
Fft<m3, 21, 297449090> muls3;
/// Polynomial product of a and b over mint (modulo-dependent convolution).
/// For md == 998244353 a single transform modulo md is enough; for any other md the
/// product is computed modulo three transform-friendly primes and recombined with crt().
vector<mint> operator*(vector<mint> a, vector<mint> b)
{
    if (a.empty() || b.empty())
        return {};

    // Strip the ModInt wrapper: the transforms work on plain ints.
    vector<int> rawA;
    rawA.reserve(a.size());
    for (const mint &coef: a)
        rawA.push_back(coef.value);
    vector<int> rawB;
    rawB.reserve(b.size());
    for (const mint &coef: b)
        rawB.push_back(coef.value);

    vector<int> prod = muls.convolution(rawA, rawB);
    if (md != 998244353)
    {
        const vector<int> prod2 = muls2.convolution(rawA, rawB);
        const vector<int> prod3 = muls3.convolution(rawA, rawB);
        assert(prod.size()==prod2.size() && prod2.size()==prod3.size());
        const __int128 m1 = 998244353;
        for (size_t i = 0; i < prod.size(); ++i)
        {
            const __int128 res1 = prod[i];
            const __int128 res2 = prod2[i];
            const __int128 res3 = prod3[i];
            const __int128 firstTwo = crt(res1, m1, res2, m2);
            const __int128 exact = crt(firstTwo, m1 * 1LL * m2, res3, m3);
            prod[i] = static_cast<int>(exact % md);
        }
    }

    // Wrap the plain residues back into mint.
    vector<mint> c;
    for (const int coef: prod)
        c.push_back(coef);
    return c;
}
vector<vector<mint> > gaussbasis(vector<vector<mint> > A) ///returns basis of Av=0
{
int n = A.size();
int m = A[0].size();
int bi = 0;
for (int i = 0; i < n; ++i)
{
if (bi == m)
break;
for (int j = i; j < n; ++j)
{
if (A[j][bi] != 0)
{
if (j != i) { swap(A[i], A[j]); }
break;
}
}
if (A[i][bi] != 0)
{
mint o = 1 / A[i][bi];
for (int j = i + 1; j < n; ++j)
{
mint we = (A[j][bi] * o);
for (int k = bi; k < m; ++k)
{
A[j][k] -= we * A[i][k];
}
}
}
else
{
++bi;
--i;
continue;
}
}
vector<int> indices(m);
iota(all(indices), 0);
for (int i = n - 1; i >= 0; --i)
{
int bi = 0;
while (bi < m && A[i][bi] == 0) { ++bi; }
if (bi < m)
{
indices.erase(find(all(indices), bi));
}
}
vector<vector<mint> > v(indices.size(), vector<mint>(m, 0));
for (int i = 0; i < indices.size(); ++i)
{
v[i][indices[i]] = 1;
}
for (int i = n - 1; i >= 0; --i)
{
int bi = 0;
while (bi < m && A[i][bi] == 0) { ++bi; }
if (bi == m)
continue;
for (int k = 0; k < indices.size(); ++k)
{
mint cur = 0;
for (int j = bi + 1; j < m; ++j)
{
cur -= A[i][j] * v[k][j];
}
v[k][bi] = cur / A[i][bi];
}
}
return v;
}
optional<vector<mint> > gauss(vector<vector<mint> > A, vector<mint> b) ///returns v such that Av=b
{
int n = A.size();
assert(b.size()==n);
int m = A[0].size();
int bi = 0;
for (int i = 0; i < n; ++i)
{
if (bi == m)
break;
for (int j = i; j < n; ++j)
{
if (A[j][bi] != 0)
{
if (j != i)
{
swap(A[i], A[j]);
swap(b[i], b[j]);
}
break;
}
}
if (A[i][bi] != 0)
{
mint o = inv(A[i][bi]);
for (int j = i + 1; j < n; ++j)
{
mint we = (A[j][bi] * o);
b[j] -= we * b[i];
for (int k = bi; k < m; ++k)
{
A[j][k] -= we * A[i][k];
}
}
}
else
{
++bi;
--i;
continue;
}
}
vector<mint> v(m);
for (int i = n - 1; i >= 0; --i)
{
int bi = 0;
while (bi < m && A[i][bi] == 0) { ++bi; }
if (bi == m)
{
if (b[i] != 0) { return nullopt; }
else { continue; }
} {
mint cur = b[i];
for (int j = bi + 1; j < m; ++j)
{
cur -= A[i][j] * v[j];
}
v[bi] = cur * inv(A[i][bi]);
}
}
return v;
}
optional<vector<vector<mint> > > findPrecursion(vector<mint> a)
{ ///finds P-recursion of a given sequence A by gauss
for (int snd = 0; snd <= 20; ++snd)
{
for (int n = 1; n <= snd - 1; ++n)
{
vector<vector<mint> > A;
int d = snd - n;
int eq = ((int) (a.size())) - (n - 1);
if (eq < n * d) { continue; }
for (int i = n - 1; i < a.size(); ++i)
{
vector<mint> u;
for (int j = 0; j < n; ++j)
{
mint de = 1;
for (int k = 0; k < d; ++k)
{
u.app(a[i - j] * de);
de *= i;
}
}
A.app(u);
}
vector<vector<mint> > zx = gaussbasis(A);
if (zx.empty())
continue;
//debug(n, d);
vector<mint> ans = zx[0];
vector<vector<mint> > res;
for (int j = 0; j < n; ++j)
{
res.app({});
for (int k = 0; k < d; ++k)
{
res[j].app(ans[j * d + k]);
}
}
return res;
}
}
return nullopt;
}
optional<vector<mint> > evaluatePrecursion(vector<mint> a, vector<vector<mint> > rec, int sz)
{ ///a(0),...,a(a.size()-1) -> (by P-recursion rec) a(0),...,a(sz-1)
int n = rec.size();
int d = rec[0].size();
int given = a.size();
if (given >= sz)
{
a.resize(sz);
return a;
}
if (a.size() < n) { return nullopt; }
vector<mint> tore;
for (int i = given; i < sz; ++i)
{
mint de = 1;
mint s = 0;
for (int k = 0; k < d; ++k)
{
s += de * rec[0][k];
de *= i;
}
if (s == 0) { return nullopt; }
tore.app(s);
}
vector<mint> pref(tore.size() + 1);
pref[0] = 1;
for (int i = 0; i < tore.size(); ++i) { pref[i + 1] = pref[i] * tore[i]; }
mint pro = pref[tore.size()];
mint invpro = 1 / pro;
mint cur = invpro;
vector<mint> invtore(tore.size());
for (int i = tore.size() - 1; i >= 0; --i)
{
invtore[i] = cur * pref[i];
cur *= tore[i];
}
for (int i = given; i < sz; ++i)
{
mint chi = 0;
for (int j = 1; j < n; ++j)
{
mint de = 1;
for (int k = 0; k < d; ++k)
{
chi += a[i - j] * de * rec[j][k];
de *= i;
}
}
a.app(((mint) (0)) - chi * invtore[i - given]);
}
return a;
}
/// Evaluates the polynomial with coefficients a (lowest degree first) at x, i.e. A(x).
/// An empty coefficient list evaluates to 0.
mint value(vector<mint> a, mint x)
{
    // Horner's scheme: fold in the coefficients from the highest degree down.
    mint acc = 0;
    for (auto it = a.rbegin(); it != a.rend(); ++it)
    {
        acc *= x;
        acc += *it;
    }
    return acc;
}
/// Extends the samples a = P(0),...,P(t) (t = a.size() - 1) to P(0),...,P(4t+1) and
/// returns all 4t+2 values. The formulas below are Lagrange interpolation on the nodes
/// 0..t, so P is presumably a polynomial of degree at most t.
/// NOTE(review): inverting fact[4t+1] requires 4t+1 < md, and `a` must be non-empty
/// (an empty input makes the vector sizes below negative) - confirm against callers.
vector<mint> shiftofsamplingpoints(vector<mint> a)
{ ///P(0),...,P(t) we want to compute P(0),...,P(4t+1)
int t = a.size() - 1;
// fact[i] = i!, invf[i] = 1 / i!, for i in [0, 4t+1].
vector<mint> fact(4 * t + 2);
fact[0] = 1;
for (int i = 1; i < 4 * t + 2; ++i)
fact[i] = fact[i - 1] * i;
vector<mint> invf(4 * t + 2);
invf[4 * t + 1] = 1 / fact[4 * t + 1];
for (int i = 4 * t; i >= 0; --i) { invf[i] = (invf[i + 1] * (i + 1)); }
assert(invf[0]==1);
// invm[i] = (i-1)! / i! = 1 / i for i >= 1; invm[0] stays 0.
vector<mint> invm(4 * t + 2, 0);
for (int i = 1; i < 4 * t + 2; ++i) { invm[i] = fact[i - 1] * invf[i]; }
// values[k] = P(k) * (-1)^(t-k) / (k! * (t-k)!), the Lagrange weight of node k.
vector<mint> values(t + 1, 0);
for (int k = 0; k <= t; ++k)
{
mint o = 1;
if ((t - k) % 2 == 1) { o = (((mint) (0)) - 1); }
values[k] = (a[k] * invf[k] * invf[t - k] * o);
}
// Convolution: h[x] = sum over k of values[k] * invm[x - k], i.e. sum of values[k] / (x - k)
// for x > t (the invm[0] = 0 term never contributes there).
vector<mint> h = invm * values;
vector<mint> res;
// The first t + 1 outputs are the given samples themselves.
for (int i = 0; i <= t; ++i)
{
res.app(a[i]);
}
// For x > t: P(x) = (x! / (x-t-1)!) * h[x], where x! / (x-t-1)! = x * (x-1) * ... * (x-t).
for (int x = t + 1; x <= 4 * t + 1; ++x)
{
mint ans = fact[x];
ans *= invf[x - t - 1];
ans *= h[x];
res.app(ans);
}
return res;
}
/// Computes a(id) from the known prefix `a` and the P-recursion `rec` without generating
/// every intermediate term (author's stated cost: O(sqrt(id)*log(id))).
/// rec[j] is the polynomial (lowest degree first, as consumed by value()) that multiplies
/// a(i - j) in the recurrence.
/// Returns a[id] directly when it is already known, 0 when rec has a single term, and
/// nullopt when too few initial terms are given or when the accumulated product of
/// leading-polynomial values is 0 modulo md.
optional<mint> evaluatePrecursionfast(vector<mint> a, vector<vector<mint> > rec, int id)
{ ///a(0),...,a(a.size()-1) -> (by P-recursion rec) a(id), O(sqrt(id)*log(id))
if (id < a.size())
return a[id];
int n = rec.size();
// d: number of sample points minus one kept per polynomial; at least 1 and at least the
// largest coefficient count minus one (an upper bound on the degree).
int d = 1;
for (auto &v: rec) { d = max(d, ((int) (v.size() - 1))); }
if (a.size() < n - 1) { return nullopt; }
// With a single term the recurrence reads rec[0](i) * a(i) = 0; the code answers 0.
if (n == 1) { return 0; }
// l: order of the recurrence = length of the state window.
int l = n - 1;
// Keep only the last l known terms; `shift` remembers how many were dropped so that
// local index i corresponds to original index i + shift, and id becomes window-relative.
int shift = 0;
while (a.size() > l)
{
a.erase(a.begin());
++shift;
--id;
}
// u: smallest exponent with 4^u > id; B = 2^u is the block length.
int u = 0;
while ((1LL << u) * (1LL << u) <= id) { ++u; }
int B = (1 << u);
// S[k]: leading polynomial rec[0] sampled at index k + l + shift (k = 0..d).
// A[i][j][k]: entry (i, j) of the one-step transition matrix sampled at the same index.
vector<mint> S;
vector<vector<vector<mint> > > A(l, vector<vector<mint> >(l));
int sz = d;
for (int k = 0; k <= d; ++k)
{
S.app(value(rec[0], k + l + shift));
}
// Rows 0..l-2 shift the window: new[i] = rec[0](index) * old[i+1].
for (int i = 0; i < l - 1; ++i)
{
for (int j = 0; j < l; ++j)
{
if (j == i + 1)
{
for (int k = 0; k <= d; ++k)
{
A[i][j].app(value(rec[0], k + l + shift));
}
}
else
{
for (int k = 0; k <= d; ++k)
{
A[i][j].app(0);
}
}
}
}
// Last row produces the new term, still multiplied by rec[0](index):
// new[l-1] = -sum over j of rec[l-j](index) * old[j].
for (int j = 0; j < l; ++j)
{
for (int k = 0; k <= d; ++k)
{
A[l - 1][j].app(((mint) (0)) - value(rec[l - j], k + l + shift));
}
}
// Doubling: each round extends every sample list from sz+1 to 4*sz+2 points and then
// multiplies neighbouring pairs, so that after the round sample k covers twice as many
// consecutive steps as before.
for (int s = 0; s < u; ++s)
{
S = shiftofsamplingpoints(S);
assert(S.size()==4*sz+2);
for (int i = 0; i < l; ++i)
{
for (int j = 0; j < l; ++j)
{
A[i][j] = shiftofsamplingpoints(A[i][j]);
assert(A[i][j].size()==4*sz+2);
}
}
vector<vector<vector<mint> > > newA(l, vector<vector<mint> >(l, vector<mint>(2 * sz + 1, 0)));
// newA(k) = A(2k+1) * A(2k): the later step is applied on the left.
for (int k = 0; k <= 2 * sz; ++k)
{
for (int ii = 0; ii < l; ++ii)
{
for (int jj = 0; jj < l; ++jj)
{
for (int kk = 0; kk < l; ++kk)
{
newA[ii][kk][k] += A[ii][jj][2 * k + 1] * A[jj][kk][2 * k];
}
}
}
}
vector<mint> newS(2 * sz + 1, 0);
for (int k = 0; k <= 2 * sz; ++k) { newS[k] = S[2 * k] * S[2 * k + 1]; }
sz *= 2;
S = newS;
A = newA;
}
// k: number of whole blocks of B steps that fit before id.
int k = (id - l) / B;
assert(k>=0); ///id>=l+1 here
// pro accumulates the product of all leading-polynomial values the state was scaled by.
mint pro = 1;
vector<mint> v;
for (int i = 0; i < l; ++i)
v.app(a[i]);
// Apply the k block matrices (sample i of every entry) to the state vector.
for (int i = 0; i < k; ++i)
{
vector<mint> newv(l, 0);
vector<vector<mint> > M(l, vector<mint>(l, 0));
for (int ii = 0; ii < l; ++ii)
{
for (int jj = 0; jj < l; ++jj)
{
assert(i<A[ii][jj].size());
M[ii][jj] = A[ii][jj][i];
}
}
pro *= S[i];
for (int ii = 0; ii < l; ++ii)
{
for (int jj = 0; jj < l; ++jj)
{
newv[ii] += M[ii][jj] * v[jj];
}
}
v = newv;
}
// cur: window-relative index of the last entry of v; finish the remaining steps one by one.
int cur = k * B + l - 1;
assert(cur<id);
while (cur < id)
{
mint newval = 0;
for (int i = 0; i < l; ++i)
{
newval -= v[i] * value(rec[l - i], cur + 1 + shift);
}
// Scale the old entries by rec[0](cur+1) instead of dividing the new one, so that
// the whole vector keeps the common factor tracked in pro.
mint de = value(rec[0], cur + 1 + shift);
pro *= de;
for (mint &x: v) { x *= de; }
v.erase(v.begin());
v.app(newval);
++cur;
}
// A zero product means some leading-polynomial value vanished; the division is impossible.
if (pro == 0)
return nullopt;
return v.back() / pro;
}
/// Finds a P-recursion for a and, if one exists, returns a(0),...,a(sz-1).
/// When a already holds at least sz terms it is truncated and returned without any search.
/// Returns nullopt if no recursion is found or the extension fails.
optional<vector<mint> > sequenceextender(vector<mint> a, int sz)
{
    if (a.size() >= sz)
    {
        a.resize(sz);
        return a;
    }
    const auto recurrence = findPrecursion(a);
    if (!recurrence.has_value())
        return nullopt;
    return evaluatePrecursion(a, *recurrence, sz);
}
/// Finds a P-recursion for a and, if one exists, computes a(id) with
/// evaluatePrecursionfast() (author's stated cost: O(sqrt(id)*log(id))).
/// A term that is already part of a is returned directly.
optional<mint> fastgetvaluebyid(vector<mint> a, int id)
{
    if (a.size() > id)
        return a[id];
    const auto recurrence = findPrecursion(a);
    if (!recurrence.has_value())
        return nullopt;
    return evaluatePrecursionfast(a, *recurrence, id);
}
/// Finds a P-recursion for a and, if one exists, computes a(id) with whichever of the
/// linear method (evaluatePrecursion) and the fast method (evaluatePrecursionfast) has
/// the lower estimated cost. A term that is already part of a is returned directly.
optional<mint> optimalgetvaluebyid(vector<mint> a, int id)
{
    if (a.size() > id)
        return a[id];
    const auto recurrence = findPrecursion(a);
    if (!recurrence.has_value())
        return nullopt;
    const auto &rec = *recurrence;

    const int n = rec.size();
    const int d = rec[0].size();
    // Convolutions cost three transforms-worth more when md is not 998244353.
    const double C = (md != 998244353) ? 3 : 1;
    const double fastCost = sqrt(id) * log(id) * C * n * n * d + sqrt(id) * n * n * n * d;
    const double linearCost = n * 1.0 * d * 1.0 * id;

    if (fastCost < linearCost)
        return evaluatePrecursionfast(a, rec, id);

    const auto extended = evaluatePrecursion(a, rec, id + 1);
    if (!extended.has_value())
        return nullopt;
    return (*extended)[id];
}
vector<vector<mint> > transpose(vector<vector<mint> > a)
{ ///transposes the table A
if (a.empty())
return a;
int n = a.size();
int m = a[0].size();
vector<vector<mint> > b(m, vector<mint>(n, 0));
for (int i = 0; i < n; ++i)
for (int j = 0; j < m; ++j)
b[j][i] = a[i][j];
return b;
}
/// If the size of a is big TL (too many Gauss), if the size of a is small WA (not enough for finding the P-recursion), you should keep balance
optional<vector<mint> > extendtableslow(vector<vector<mint> > a, vector<pair<int, int> > que)
{ /// extends table A, finding A(que[i].first,que[i].second), in a O(sum(que[i]))
if (que.empty())
{
vector<mint> ret = {};
return ret;
}
int ma = 0;
for (auto [i,j]: que) { ma = max(ma, j); }
int n = a.size();
int m = a[0].size();
vector<vector<mint> > ex(n);
for (int i = 0; i < n; ++i)
{
auto h = sequenceextender(a[i], ma + 1);
if (!h) { return nullopt; }
ex[i] = (*h);
}
vector<mint> res;
for (auto [i,j]: que)
{
vector<mint> e;
for (int k = 0; k < n; ++k)
{
e.app(ex[k][j]);
}
auto h = sequenceextender(e, i + 1);
if (!h) { return nullopt; }
res.app((*h)[i]);
}
return res;
}
optional<vector<mint> > extendtablefast(vector<vector<mint> > a, vector<pair<int, int> > que)
{ /// extends table A, finding A(que[i].first,que[i].second)
if (que.empty())
{
vector<mint> ret = {};
return ret;
}
int n = a.size();
int m = a[0].size();
vector<vector<mint> > ex(n);
vector<vector<mint> > rec[n];
for (int k = 0; k < n; ++k)
{
auto uu = findPrecursion(a[k]);
if (!uu) { return nullopt; }
rec[k] = (*uu);
}
vector<mint> res;
for (auto [i,j]: que)
{
vector<mint> e;
for (int k = 0; k < n; ++k)
{
auto uu = evaluatePrecursionfast(a[k], rec[k], j);
if (!uu) { return nullopt; }
e.app(*uu);
}
auto h = optimalgetvaluebyid(e, i);
if (!h) { return nullopt; }
res.app(*h);
}
return res;
}
/// Extends the table a and returns A(que[i].first, que[i].second) for every query,
/// delegating to extendtablefast() or extendtableslow() depending on which has the
/// lower estimated operation count for the given queries.
optional<vector<mint> > extendtable(vector<vector<mint> > a, vector<pair<int, int> > que)
{
    const double C = (md != 998244353) ? 3 : 1;
    const int tableRows = a.size();

    double fastCost = 0;
    double slowCost = 0;
    for (const auto &[i, j]: que)
    {
        // Square-root-style estimate for the fast path ...
        fastCost += C * 1.0 * tableRows * 1.0 * sqrt(i + 1) * 1.0 * log(i + 2) * 5 * 5 * 5;
        fastCost += C * sqrt(j + 1) * log(j + 2) * 5 * 5 * 5;
        // ... and a linear estimate for the slow path.
        slowCost += C * 1.0 * tableRows * 1.0 * (i + 1) * 1.0 * 5 * 5;
        slowCost += C * 1.0 * (j + 1) * 1.0 * 5 * 5;
    }
    return fastCost < slowCost ? extendtablefast(a, que) : extendtableslow(a, que);
}
optional<vector<mint> > getcolumnoftable(vector<vector<mint> > a, int col, int size)
{ /// get A(0,col),A(1,col),...,A(size-1,col)
int n = a.size();
int m = a[0].size();
vector<mint> e;
for (int i = 0; i < n; ++i)
{
auto h = sequenceextender(a[i], col + 1);
if (!h) { return nullopt; }
e.app((*h)[col]);
}
auto h = sequenceextender(e, size);
if (!h)
return nullopt;
return *h;
}
/// Returns A(row,0), A(row,1), ..., A(row,size-1) by running getcolumnoftable() on the
/// transposed table.
optional<vector<mint> > getrowoftable(vector<vector<mint> > a, int row, int size)
{
    const vector<vector<mint> > flipped = transpose(a);
    return getcolumnoftable(flipped, row, size);
}
/// Demo / smoke test of the sequence and table helpers. It is not called from main().
/// Results are only printed through debug/debugv, which expand to nothing unless LOCAL
/// is defined, so in a normal build the function just runs the computations.
void test()
{
// --- sequenceextender: extend a sequence prefix to a longer prefix ---
vector<mint> v1 = {1, 1, 3, 7, 19, 51, 141, 393, 1107, 3139}; ///(1+x+x^2)^n [x^n]
auto uu1 = sequenceextender(v1, 30);
if (uu1)
{
auto u1 = (*uu1);
debugv(u1);
}
vector<mint> v2 = {1, 30, 465, 4930, 40020, 264306, 1474795, 7133130, 30462615, 116470380}; ///(1+x+x^2)^30 [x^n]
auto uu2 = sequenceextender(v2, 70);
if (uu2)
{
auto u2 = (*uu2);
debugv(u2);
}
vector<mint> v3 = {0, 567646151, 513265721, 604121291, 715018514, 398975714, 610803800, 499563577, 491416403,
913506524
}; ///s(30,n), not D-finite
auto uu3 = sequenceextender(v3, 35);
if (uu3)
{
auto u3 = (*uu3);
debugv(u3);
}
vector<mint> v4 = {0, 1, 2, 9, 44, 265, 1854, 14833, 133496, 1334961, 14684570};
///n!*x*e^(-x) [x^n] (number of permutations of size n without stable points)
auto uu4 = sequenceextender(v4, 20);
if (uu4)
{
auto u4 = (*uu4);
debugv(u4);
}
// --- fastgetvaluebyid: a single term far beyond the given prefix ---
auto uu5 = fastgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, 11); ///factorials ,by id
if (uu5)
{
auto u5 = (*uu5);
debug(u5);
}
auto uu6 = fastgetvaluebyid(v1, 29); ///(1+x+x^2)^n [x^n], by id
if (uu6)
{
auto u6 = (*uu6);
debug(u6);
}
auto uu7 = fastgetvaluebyid(v1, 500000000); ///(1+x+x^2)^n [x^n], by id
if (uu7)
{
auto u7 = (*uu7);
debug(u7);
}
// --- optimalgetvaluebyid: same queries, method chosen by estimated cost ---
auto uu8 = optimalgetvaluebyid(v1, 500000000); ///(1+x+x^2)^n [x^n], by id
if (uu8)
{
auto u8 = (*uu8);
debug(u8);
}
auto uu9 = optimalgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, 11); ///factorials ,by id
if (uu9)
{
auto u9 = (*uu9);
debug(u9);
}
auto uu10 = optimalgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, 998244352); ///factorials ,by id
if (uu10)
{
auto u10 = (*uu10);
debug(u10);
}
// --- table helpers: all three variants are run on the same table and queries ---
vector<vector<mint> > table1 =
{{1, 2, 3, 4, 5}, {2, 3, 4, 5, 6}, {3, 4, 5, 6, 7}, {4, 5, 6, 7, 8}, {5, 6, 7, 8, 9}}; ///f(i,j)=i+j+1
vector<pair<int, int> > que1 = {{1, 1}, {0, 0}, {5, 7}, {11342, 1333}};
auto uu11 = extendtablefast(table1, que1);
auto uu12 = extendtableslow(table1, que1);
auto uu13 = extendtable(table1, que1);
debug((bool) (uu11));
debug((bool) (uu12));
debug((bool) (uu13));
if (uu11 && uu12 && uu13)
{
auto u11 = (*uu11);
auto u12 = (*uu12);
auto u13 = (*uu13);
debugv(u11);
debugv(u12);
debugv(u13);
}
// --- whole column / whole row of the extended table ---
auto uu14 = getcolumnoftable(table1, 10, 100);
if (uu14)
{
auto u14 = (*uu14);
debugv(u14);
}
auto uu15 = getrowoftable(table1, 50, 70);
if (uu15)
{
auto u15 = (*uu15);
debugv(u15);
}
///exp((x+y)/((1-x)(1-y)), I removed it from examples
}
int32_t main()
{
ios_base::sync_with_stdio(false);
cin.tie(0);
int n,k;cin>>n>>k;
k=min(k,n-k);
if(n>=md) {n-=md;}
if(n<k) {
cout<<0;
return 0;
}
auto uu5 = fastgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, n); ///factorials ,by id
auto uu6 = fastgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, k); ///factorials ,by id
auto uu7 = fastgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, n-k); ///factorials ,by id
mint ans=((*uu5)/(*uu6))/(*uu7);
cout<<ans;
return 0;
}