#include using namespace std; //#define int long long ///You can add it if you want const int md = 998244353; mt19937 rnd; #define app push_back #define all(x) (x).begin(),(x).end() #ifdef LOCAL #define debug(...) [](auto...a){ ((cout << a << ' '), ...) << endl;}(#__VA_ARGS__, ":", __VA_ARGS__) #define debugv(v) do {cout<< #v <<" : {"; for(int izxc=0;izxc struct Fft { // 1, 1/4, 1/8, 3/8, 1/16, 5/16, 3/16, 7/16, ... int g[1 << (K - 1)]; Fft() : g() { //if tl constexpr... // static_assert(K >= 2, "Fft: K >= 2 must hold"); g[0] = 1; g[1 << (K - 2)] = G; for (int l = 1 << (K - 2); l >= 2; l >>= 1) { g[l >> 1] = (g[l] * 1LL * g[l]) % M; } assert((g[1]*1LL * g[1]) % M == M - 1); for (int l = 2; l <= 1 << (K - 2); l <<= 1) { for (int i = 1; i < l; ++i) { g[l + i] = (g[l] * 1LL * g[i]) % M; } } } void fft(vector &x) const { const int n = x.size(); assert(n <= 1 << K); for (int h = __builtin_ctz(n); h--;) { const int l = (1 << h); for (int i = 0; i < n >> (h + 1); ++i) { for (int j = i << (h + 1); j < (((i << 1) + 1) << h); ++j) { const int t = (g[i] * 1LL * x[j | l]) % M; x[j | l] = x[j] - t; if (x[j | l] < 0) x[j | l] += M; x[j] += t; if (x[j] >= M) x[j] -= M; } } } for (int i = 0, j = 0; i < n; ++i) { if (i < j) std::swap(x[i], x[j]); for (int l = n; (l >>= 1) && !((j ^= l) & l);) { } } } vector convolution(vector a, vector b) const { if (a.empty() || b.empty()) return {}; const int p = M; for (int &x: a) { x %= p; if (x >= p) x -= p; if (x < 0) x += p; } for (int &x: b) { x %= p; if (x >= p) x -= p; if (x < 0) x += p; } const int na = a.size(), nb = b.size(); int n, invN = 1; for (n = 1; n < na + nb - 1; n <<= 1) invN = ((invN & 1) ? (invN + M) : invN) >> 1; vector x(n, 0), y(n, 0); std::copy(a.begin(), a.end(), x.begin()); std::copy(b.begin(), b.end(), y.begin()); fft(x); fft(y); for (int i = 0; i < n; ++i) x[i] = (((static_cast(x[i]) * y[i]) % M) * invN) % M; std::reverse(x.begin() + 1, x.end()); fft(x); x.resize(na + nb - 1); return x; } }; Fft<998244353, 21, 31 * 31 * 31 * 31> muls; template struct ModInt { int32_t value; ModInt() : value(0) { } ModInt(long long v) : value(v % MOD) { if (value < 0) value += MOD; } ModInt(int32_t v): value(v % MOD) { if (value < 0) value += MOD; } ModInt operator+=(ModInt m) { value += m.value; if (value >= MOD) value -= MOD; return value; } ModInt operator-=(ModInt m) { value -= m.value; if (value < 0) value += MOD; return value; } ModInt operator*=(ModInt m) { value = (value * 1LL * m.value) % MOD; return value; } ModInt power(long long exp) const { if (exp == 0) return 1; ModInt res = (exp & 1 ? value : 1); ModInt half = power(exp >> 1); return res * half * half; } ModInt operator/=(ModInt m) { return *this *= m.power(MOD - 2); } friend std::istream &operator>>(std::istream &is, ModInt &m) { is >> m.value; return is; } friend std::ostream &operator<<(std::ostream &os, const ModInt &m) { os << m.value; return os; } explicit operator int32_t() const { return value; } explicit operator long long() const { return value; } static int32_t mod() { return MOD; } }; template ModInt operator+(ModInt a, ModInt b) { return a += b; } template ModInt operator+(L a, ModInt b) { return ModInt(a) += b; } template ModInt operator+(ModInt a, R b) { return a += b; } template ModInt operator-(ModInt a, ModInt b) { return a -= b; } template ModInt operator-(L a, ModInt b) { return ModInt(a) -= b; } template ModInt operator-(ModInt a, R b) { return a -= b; } template ModInt operator*(ModInt a, ModInt b) { return a *= b; } template ModInt operator*(L a, ModInt b) { return ModInt(a) *= b; } template ModInt operator*(ModInt a, R b) { return a *= b; } template ModInt operator/(ModInt a, ModInt b) { return a /= b; } template ModInt operator/(L a, ModInt b) { return ModInt(a) /= b; } template ModInt operator/(ModInt a, R b) { return a /= b; } template bool operator==(ModInt a, ModInt b) { return a.value == b.value; } template bool operator==(L a, ModInt b) { return a == b.value; } template bool operator==(ModInt a, R b) { return a.value == b; } template bool operator!=(ModInt a, ModInt b) { return a.value != b.value; } template bool operator!=(L a, ModInt b) { return a != b.value; } template bool operator!=(ModInt a, R b) { return a.value != b; } using mint = ModInt; mint inv(mint x) { return 1 / x; } __int128 gcd(__int128 a, __int128 b, __int128 &x, __int128 &y) { if (b == 0) { x = 1; y = 0; return a; } __int128 d = gcd(b, a % b, y, x); y -= a / b * x; return d; } __int128 inv(__int128 r, __int128 m) { __int128 x, y; gcd(r, m, x, y); return (x + m) % m; } __int128 crt(__int128 r, __int128 n, __int128 c, __int128 m) { return r + ((c - r) % m + m) * inv(n, m) % m * n; } const int m2 = 167772161, m3 = 469762049; Fft muls2; Fft muls3; vector operator*(vector a, vector b) { ///modulo-dependent convolution if (a.empty() || b.empty()) return {}; if (md == 998244353) { vector a1(a.size()); for (int i = 0; i < a.size(); ++i) a1[i] = a[i].value; vector b1(b.size()); for (int i = 0; i < b.size(); ++i) b1[i] = b[i].value; vector c1 = muls.convolution(a1, b1); vector c; for (int x: c1) c.app(x); return c; } else { vector a1(a.size()); for (int i = 0; i < a.size(); ++i) a1[i] = a[i].value; vector b1(b.size()); for (int i = 0; i < b.size(); ++i) b1[i] = b[i].value; vector c1 = muls.convolution(a1, b1); vector c2 = muls2.convolution(a1, b1); vector c3 = muls3.convolution(a1, b1); assert(c1.size()==c2.size() && c2.size()==c3.size()); vector c4(c1.size()); for (int i = 0; i < c1.size(); ++i) { __int128 ost1 = c1[i]; __int128 m1 = 998244353; __int128 ost2 = c2[i]; __int128 ost3 = c3[i]; __int128 ost = crt(crt(ost1, m1, ost2, m2), m1 * 1LL * m2, ost3, m3); c4[i] = (ost % md); } vector c; for (int x: c4) c.app(x); return c; } } vector > gaussbasis(vector > A) ///returns basis of Av=0 { int n = A.size(); int m = A[0].size(); int bi = 0; for (int i = 0; i < n; ++i) { if (bi == m) break; for (int j = i; j < n; ++j) { if (A[j][bi] != 0) { if (j != i) { swap(A[i], A[j]); } break; } } if (A[i][bi] != 0) { mint o = 1 / A[i][bi]; for (int j = i + 1; j < n; ++j) { mint we = (A[j][bi] * o); for (int k = bi; k < m; ++k) { A[j][k] -= we * A[i][k]; } } } else { ++bi; --i; continue; } } vector indices(m); iota(all(indices), 0); for (int i = n - 1; i >= 0; --i) { int bi = 0; while (bi < m && A[i][bi] == 0) { ++bi; } if (bi < m) { indices.erase(find(all(indices), bi)); } } vector > v(indices.size(), vector(m, 0)); for (int i = 0; i < indices.size(); ++i) { v[i][indices[i]] = 1; } for (int i = n - 1; i >= 0; --i) { int bi = 0; while (bi < m && A[i][bi] == 0) { ++bi; } if (bi == m) continue; for (int k = 0; k < indices.size(); ++k) { mint cur = 0; for (int j = bi + 1; j < m; ++j) { cur -= A[i][j] * v[k][j]; } v[k][bi] = cur / A[i][bi]; } } return v; } optional > gauss(vector > A, vector b) ///returns v such that Av=b { int n = A.size(); assert(b.size()==n); int m = A[0].size(); int bi = 0; for (int i = 0; i < n; ++i) { if (bi == m) break; for (int j = i; j < n; ++j) { if (A[j][bi] != 0) { if (j != i) { swap(A[i], A[j]); swap(b[i], b[j]); } break; } } if (A[i][bi] != 0) { mint o = inv(A[i][bi]); for (int j = i + 1; j < n; ++j) { mint we = (A[j][bi] * o); b[j] -= we * b[i]; for (int k = bi; k < m; ++k) { A[j][k] -= we * A[i][k]; } } } else { ++bi; --i; continue; } } vector v(m); for (int i = n - 1; i >= 0; --i) { int bi = 0; while (bi < m && A[i][bi] == 0) { ++bi; } if (bi == m) { if (b[i] != 0) { return nullopt; } else { continue; } } { mint cur = b[i]; for (int j = bi + 1; j < m; ++j) { cur -= A[i][j] * v[j]; } v[bi] = cur * inv(A[i][bi]); } } return v; } optional > > findPrecursion(vector a) { ///finds P-recursion of a given sequence A by gauss for (int snd = 0; snd <= 20; ++snd) { for (int n = 1; n <= snd - 1; ++n) { vector > A; int d = snd - n; int eq = ((int) (a.size())) - (n - 1); if (eq < n * d) { continue; } for (int i = n - 1; i < a.size(); ++i) { vector u; for (int j = 0; j < n; ++j) { mint de = 1; for (int k = 0; k < d; ++k) { u.app(a[i - j] * de); de *= i; } } A.app(u); } vector > zx = gaussbasis(A); if (zx.empty()) continue; //debug(n, d); vector ans = zx[0]; vector > res; for (int j = 0; j < n; ++j) { res.app({}); for (int k = 0; k < d; ++k) { res[j].app(ans[j * d + k]); } } return res; } } return nullopt; } optional > evaluatePrecursion(vector a, vector > rec, int sz) { ///a(0),...,a(a.size()-1) -> (by P-recursion rec) a(0),...,a(sz-1) int n = rec.size(); int d = rec[0].size(); int given = a.size(); if (given >= sz) { a.resize(sz); return a; } if (a.size() < n) { return nullopt; } vector tore; for (int i = given; i < sz; ++i) { mint de = 1; mint s = 0; for (int k = 0; k < d; ++k) { s += de * rec[0][k]; de *= i; } if (s == 0) { return nullopt; } tore.app(s); } vector pref(tore.size() + 1); pref[0] = 1; for (int i = 0; i < tore.size(); ++i) { pref[i + 1] = pref[i] * tore[i]; } mint pro = pref[tore.size()]; mint invpro = 1 / pro; mint cur = invpro; vector invtore(tore.size()); for (int i = tore.size() - 1; i >= 0; --i) { invtore[i] = cur * pref[i]; cur *= tore[i]; } for (int i = given; i < sz; ++i) { mint chi = 0; for (int j = 1; j < n; ++j) { mint de = 1; for (int k = 0; k < d; ++k) { chi += a[i - j] * de * rec[j][k]; de *= i; } } a.app(((mint) (0)) - chi * invtore[i - given]); } return a; } mint value(vector a, mint x) { ///A(x) mint de = 1; mint ans = 0; for (int i = 0; i < a.size(); ++i) { ans += a[i] * de; de *= x; } return ans; } vector shiftofsamplingpoints(vector a) { ///P(0),...,P(t) we want to compute P(0),...,P(4t+1) int t = a.size() - 1; vector fact(4 * t + 2); fact[0] = 1; for (int i = 1; i < 4 * t + 2; ++i) fact[i] = fact[i - 1] * i; vector invf(4 * t + 2); invf[4 * t + 1] = 1 / fact[4 * t + 1]; for (int i = 4 * t; i >= 0; --i) { invf[i] = (invf[i + 1] * (i + 1)); } assert(invf[0]==1); vector invm(4 * t + 2, 0); for (int i = 1; i < 4 * t + 2; ++i) { invm[i] = fact[i - 1] * invf[i]; } vector values(t + 1, 0); for (int k = 0; k <= t; ++k) { mint o = 1; if ((t - k) % 2 == 1) { o = (((mint) (0)) - 1); } values[k] = (a[k] * invf[k] * invf[t - k] * o); } vector h = invm * values; vector res; for (int i = 0; i <= t; ++i) { res.app(a[i]); } for (int x = t + 1; x <= 4 * t + 1; ++x) { mint ans = fact[x]; ans *= invf[x - t - 1]; ans *= h[x]; res.app(ans); } return res; } optional evaluatePrecursionfast(vector a, vector > rec, int id) { ///a(0),...,a(a.size()-1) -> (by P-recursion rec) a(id), O(sqrt(id)*log(id)) if (id < a.size()) return a[id]; int n = rec.size(); int d = 1; for (auto &v: rec) { d = max(d, ((int) (v.size() - 1))); } if (a.size() < n - 1) { return nullopt; } if (n == 1) { return 0; } int l = n - 1; int shift = 0; while (a.size() > l) { a.erase(a.begin()); ++shift; --id; } int u = 0; while ((1LL << u) * (1LL << u) <= id) { ++u; } int B = (1 << u); vector S; vector > > A(l, vector >(l)); int sz = d; for (int k = 0; k <= d; ++k) { S.app(value(rec[0], k + l + shift)); } for (int i = 0; i < l - 1; ++i) { for (int j = 0; j < l; ++j) { if (j == i + 1) { for (int k = 0; k <= d; ++k) { A[i][j].app(value(rec[0], k + l + shift)); } } else { for (int k = 0; k <= d; ++k) { A[i][j].app(0); } } } } for (int j = 0; j < l; ++j) { for (int k = 0; k <= d; ++k) { A[l - 1][j].app(((mint) (0)) - value(rec[l - j], k + l + shift)); } } for (int s = 0; s < u; ++s) { S = shiftofsamplingpoints(S); assert(S.size()==4*sz+2); for (int i = 0; i < l; ++i) { for (int j = 0; j < l; ++j) { A[i][j] = shiftofsamplingpoints(A[i][j]); assert(A[i][j].size()==4*sz+2); } } vector > > newA(l, vector >(l, vector(2 * sz + 1, 0))); for (int k = 0; k <= 2 * sz; ++k) { for (int ii = 0; ii < l; ++ii) { for (int jj = 0; jj < l; ++jj) { for (int kk = 0; kk < l; ++kk) { newA[ii][kk][k] += A[ii][jj][2 * k + 1] * A[jj][kk][2 * k]; } } } } vector newS(2 * sz + 1, 0); for (int k = 0; k <= 2 * sz; ++k) { newS[k] = S[2 * k] * S[2 * k + 1]; } sz *= 2; S = newS; A = newA; } int k = (id - l) / B; assert(k>=0); ///id>=l+1 here mint pro = 1; vector v; for (int i = 0; i < l; ++i) v.app(a[i]); for (int i = 0; i < k; ++i) { vector newv(l, 0); vector > M(l, vector(l, 0)); for (int ii = 0; ii < l; ++ii) { for (int jj = 0; jj < l; ++jj) { assert(i > sequenceextender(vector a, int sz) { ///finds P-recursion, and if was found, calculates a(0),...,a(sz-1) if (a.size() >= sz) { a.resize(sz); return a; } auto uu = findPrecursion(a); if (!uu) return nullopt; auto rec = (*uu); auto ans = evaluatePrecursion(a, rec, sz); if (!ans) return nullopt; return (*ans); } optional fastgetvaluebyid(vector a, int id) { ///finds P-recursion, and if was found, calculates a(id) in O(sqrt(id)*log(id)) if (a.size() > id) { return a[id]; } auto uu = findPrecursion(a); if (!uu) return nullopt; auto rec = (*uu); auto ans = evaluatePrecursionfast(a, rec, id); if (!ans) return nullopt; return (*ans); } optional optimalgetvaluebyid(vector a, int id) { ///finds P-recursion, and if was found, calculates a(id) by choosing optimal of O(id) method and O(sqrt(id)*log(id)) method if (a.size() > id) { return a[id]; } auto uu = findPrecursion(a); if (!uu) return nullopt; auto rec = (*uu); int n = rec.size(); int d = rec[0].size(); double C = 1; if (md != 998244353) C = 3; double val1 = sqrt(id) * log(id) * C * n * n * d + sqrt(id) * n * n * n * d; double val2 = n * 1.0 * d * 1.0 * id; //debug(val1, val2); if (val1 < val2) { //debug("fastgetvalue"); auto ans = evaluatePrecursionfast(a, rec, id); if (!ans) return nullopt; return (*ans); } else { //debug("sequenceextender"); auto ans = evaluatePrecursion(a, rec, id + 1); if (!ans) return nullopt; return (*ans)[id]; } } vector > transpose(vector > a) { ///transposes the table A if (a.empty()) return a; int n = a.size(); int m = a[0].size(); vector > b(m, vector(n, 0)); for (int i = 0; i < n; ++i) for (int j = 0; j < m; ++j) b[j][i] = a[i][j]; return b; } /// If the size of a is big TL (too many Gauss), if the size of a is small WA (not enough for finding the P-recursion), you should keep balance optional > extendtableslow(vector > a, vector > que) { /// extends table A, finding A(que[i].first,que[i].second), in a O(sum(que[i])) if (que.empty()) { vector ret = {}; return ret; } int ma = 0; for (auto [i,j]: que) { ma = max(ma, j); } int n = a.size(); int m = a[0].size(); vector > ex(n); for (int i = 0; i < n; ++i) { auto h = sequenceextender(a[i], ma + 1); if (!h) { return nullopt; } ex[i] = (*h); } vector res; for (auto [i,j]: que) { vector e; for (int k = 0; k < n; ++k) { e.app(ex[k][j]); } auto h = sequenceextender(e, i + 1); if (!h) { return nullopt; } res.app((*h)[i]); } return res; } optional > extendtablefast(vector > a, vector > que) { /// extends table A, finding A(que[i].first,que[i].second) if (que.empty()) { vector ret = {}; return ret; } int n = a.size(); int m = a[0].size(); vector > ex(n); vector > rec[n]; for (int k = 0; k < n; ++k) { auto uu = findPrecursion(a[k]); if (!uu) { return nullopt; } rec[k] = (*uu); } vector res; for (auto [i,j]: que) { vector e; for (int k = 0; k < n; ++k) { auto uu = evaluatePrecursionfast(a[k], rec[k], j); if (!uu) { return nullopt; } e.app(*uu); } auto h = optimalgetvaluebyid(e, i); if (!h) { return nullopt; } res.app(*h); } return res; } optional > extendtable(vector > a, vector > que) { /// extends table A, finding A(que[i].first,que[i].second) double C = 1; if (md != 998244353) { C = 3; } double op1 = 0; double op2 = 0; for (auto [i,j]: que) { op1 += C * 1.0 * ((int) (a.size())) * 1.0 * sqrt(i + 1) * 1.0 * log(i + 2) * 5 * 5 * 5; op1 += C * sqrt(j + 1) * log(j + 2) * 5 * 5 * 5; } for (auto [i,j]: que) { op2 += C * 1.0 * ((int) (a.size())) * 1.0 * (i + 1) * 1.0 * 5 * 5; op2 += C * 1.0 * (j + 1) * 1.0 * 5 * 5; } //debug(op1, op2); if (op1 < op2) { return extendtablefast(a, que); } else { return extendtableslow(a, que); } } optional > getcolumnoftable(vector > a, int col, int size) { /// get A(0,col),A(1,col),...,A(size-1,col) int n = a.size(); int m = a[0].size(); vector e; for (int i = 0; i < n; ++i) { auto h = sequenceextender(a[i], col + 1); if (!h) { return nullopt; } e.app((*h)[col]); } auto h = sequenceextender(e, size); if (!h) return nullopt; return *h; } optional > getrowoftable(vector > a, int row, int size) { /// get A(row,0),A(row,1),...,A(row,size-1) return getcolumnoftable(transpose(a), row, size); } void test() { vector v1 = {1, 1, 3, 7, 19, 51, 141, 393, 1107, 3139}; ///(1+x+x^2)^n [x^n] auto uu1 = sequenceextender(v1, 30); if (uu1) { auto u1 = (*uu1); debugv(u1); } vector v2 = {1, 30, 465, 4930, 40020, 264306, 1474795, 7133130, 30462615, 116470380}; ///(1+x+x^2)^30 [x^n] auto uu2 = sequenceextender(v2, 70); if (uu2) { auto u2 = (*uu2); debugv(u2); } vector v3 = {0, 567646151, 513265721, 604121291, 715018514, 398975714, 610803800, 499563577, 491416403, 913506524 }; ///s(30,n), not D-finite auto uu3 = sequenceextender(v3, 35); if (uu3) { auto u3 = (*uu3); debugv(u3); } vector v4 = {0, 1, 2, 9, 44, 265, 1854, 14833, 133496, 1334961, 14684570}; ///n!*x*e^(-x) [x^n] (number of permutations of size n without stable points) auto uu4 = sequenceextender(v4, 20); if (uu4) { auto u4 = (*uu4); debugv(u4); } auto uu5 = fastgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, 11); ///factorials ,by id if (uu5) { auto u5 = (*uu5); debug(u5); } auto uu6 = fastgetvaluebyid(v1, 29); ///(1+x+x^2)^n [x^n], by id if (uu6) { auto u6 = (*uu6); debug(u6); } auto uu7 = fastgetvaluebyid(v1, 500000000); ///(1+x+x^2)^n [x^n], by id if (uu7) { auto u7 = (*uu7); debug(u7); } auto uu8 = optimalgetvaluebyid(v1, 500000000); ///(1+x+x^2)^n [x^n], by id if (uu8) { auto u8 = (*uu8); debug(u8); } auto uu9 = optimalgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, 11); ///factorials ,by id if (uu9) { auto u9 = (*uu9); debug(u9); } auto uu10 = optimalgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, 998244352); ///factorials ,by id if (uu10) { auto u10 = (*uu10); debug(u10); } vector > table1 = {{1, 2, 3, 4, 5}, {2, 3, 4, 5, 6}, {3, 4, 5, 6, 7}, {4, 5, 6, 7, 8}, {5, 6, 7, 8, 9}}; ///f(i,j)=i+j+1 vector > que1 = {{1, 1}, {0, 0}, {5, 7}, {11342, 1333}}; auto uu11 = extendtablefast(table1, que1); auto uu12 = extendtableslow(table1, que1); auto uu13 = extendtable(table1, que1); debug((bool) (uu11)); debug((bool) (uu12)); debug((bool) (uu13)); if (uu11 && uu12 && uu13) { auto u11 = (*uu11); auto u12 = (*uu12); auto u13 = (*uu13); debugv(u11); debugv(u12); debugv(u13); } auto uu14 = getcolumnoftable(table1, 10, 100); if (uu14) { auto u14 = (*uu14); debugv(u14); } auto uu15 = getrowoftable(table1, 50, 70); if (uu15) { auto u15 = (*uu15); debugv(u15); } ///exp((x+y)/((1-x)(1-y)), I removed it from examples } int32_t main() { ios_base::sync_with_stdio(false); cin.tie(0); int n,k;cin>>n>>k; k=min(k,n-k); if(n>=md) {n-=md;} if(n