結果

問題 No.3394 Big Binom
コンテスト
ユーザー maksim
提出日時 2025-12-01 20:26:16
言語 C++23
(gcc 13.3.0 + boost 1.89.0)
結果
AC  
実行時間 258 ms / 2,000 ms
コード長 27,880 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 6,700 ms
コンパイル使用メモリ 337,860 KB
実行使用メモリ 19,448 KB
最終ジャッジ日時 2025-12-01 20:26:26
合計ジャッジ時間 9,400 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 21
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#include <bits/stdc++.h>

// Competitive-programming prelude: the std namespace is pulled in globally, the
// working modulus is fixed, and several short-hand macros are defined.
using namespace std;
//#define int long long
///You can add it if you want
// (Uncommenting the line above turns every `int` into `long long`; main is declared
//  as int32_t main() so that it keeps compiling in that mode.)

// Modulus for all mint arithmetic (mint = ModInt<md>); several routines below
// choose a single-NTT or three-prime path depending on whether md == 998244353.
const int md = 998244353;
// Default-constructed Mersenne Twister (fixed default seed); not used in the code shown here.
mt19937 rnd;
// app: short for push_back. all(x): begin/end pair of a container.
#define app push_back
#define all(x) (x).begin(),(x).end()
#ifdef LOCAL
// debug(args...): prints the stringified argument list, ":", then each value, space-separated.
// debugv(v): prints "v : {e0,e1,...}".
#define debug(...) [](auto...a){ ((cout << a << ' '), ...) << endl;}(#__VA_ARGS__, ":", __VA_ARGS__)
#define debugv(v) do {cout<< #v <<" : {"; for(int izxc=0;izxc<v.size();++izxc) {cout << v[izxc];if(izxc+1!=v.size()) cout << ","; }cout <<"}"<< endl;} while(0)
#else
// Without LOCAL both debug macros expand to nothing.
#define debug(...)
#define debugv(v)
#endif
// lower_bound / upper_bound over a whole container.
#define lob(a,x) lower_bound(all(a),x)
#define upb(a,x) upper_bound(all(a),x)

/// Number-theoretic transform modulo the prime M for lengths up to 2^K.
/// G must have multiplicative order exactly 2^K modulo M: the constructor asserts
/// G^(2^(K-1)) == M - 1, which is equivalent to that.
template<int M, int K, int G>
struct Fft
{
    // 1, 1/4, 1/8, 3/8, 1/16, 5/16, 3/16, 7/16, ...
    // g[i] is the root of unity whose angle (as a fraction of a full turn) is the
    // i-th entry of the list above; g[0] = 1 and g[1] is a square root of -1.
    int g[1 << (K - 1)];

    /// Builds the twiddle table: g[2^(K-2)] = G, repeated squaring fills the powers
    /// of two below it, and products g[l] * g[i] fill the remaining slots.
    Fft() : g()
    { //if tl constexpr...
        // static_assert(K >= 2, "Fft: K >= 2 must hold");
        g[0] = 1;
        g[1 << (K - 2)] = G;
        for (int l = 1 << (K - 2); l >= 2; l >>= 1)
        {
            g[l >> 1] = (g[l] * 1LL * g[l]) % M;
        }
        // g[1]^2 == -1 (mod M) <=> G^(2^(K-1)) == -1, i.e. G has order exactly 2^K.
        assert((g[1]*1LL * g[1]) % M == M - 1);
        for (int l = 2; l <= 1 << (K - 2); l <<= 1)
        {
            for (int i = 1; i < l; ++i)
            {
                g[l + i] = (g[l] * 1LL * g[i]) % M;
            }
        }
    }

    /// In-place forward transform. x.size() must be a power of two (log2 is taken via
    /// __builtin_ctz) and at most 2^K (only the upper bound is asserted). Entries are
    /// expected to be residues in [0, M).
    void fft(vector<int> &x) const
    {
        const int n = x.size();
        assert(n <= 1 << K);
        for (int h = __builtin_ctz(n); h--;)
        {
            const int l = (1 << h);
            // Block i of length 2l uses the single twiddle g[i].
            for (int i = 0; i < n >> (h + 1); ++i)
            {
                for (int j = i << (h + 1); j < (((i << 1) + 1) << h); ++j)
                {
                    // Butterfly: (x[j], x[j|l]) <- (x[j] + t, x[j] - t), t = g[i] * x[j|l].
                    const int t = (g[i] * 1LL * x[j | l]) % M;
                    x[j | l] = x[j] - t;
                    if (x[j | l] < 0)
                        x[j | l] += M;
                    x[j] += t;
                    if (x[j] >= M)
                        x[j] -= M;
                }
            }
        }
        // Bit-reversal permutation (j is the bit-reversed counter of i). The
        // reverse-then-transform inverse in convolution() relies on the resulting order.
        for (int i = 0, j = 0; i < n; ++i)
        {
            if (i < j)
                std::swap(x[i], x[j]);
            for (int l = n; (l >>= 1) && !((j ^= l) & l);)
            {
            }
        }
    }

    /// Cyclic-free product of a and b modulo M; returns na + nb - 1 coefficients,
    /// or an empty vector when either input is empty. Inputs are taken by value and
    /// reduced into [0, M) first (negative values allowed).
    vector<int> convolution(vector<int> a, vector<int> b) const
    {
        if (a.empty() || b.empty())
            return {};
        const int p = M;
        // Reduce into [0, p); after %= the ">= p" check can never fire (kept as is).
        for (int &x: a)
        {
            x %= p;
            if (x >= p)
                x -= p;
            if (x < 0)
                x += p;
        }
        for (int &x: b)
        {
            x %= p;
            if (x >= p)
                x -= p;
            if (x < 0)
                x += p;
        }
        const int na = a.size(), nb = b.size();
        int n, invN = 1;
        // n = smallest power of two >= na + nb - 1. invN is halved mod M once per
        // doubling (add M when odd, then shift), so invN == n^{-1} (mod M) for odd M.
        for (n = 1; n < na + nb - 1; n <<= 1)
            invN = ((invN & 1) ? (invN + M) : invN) >> 1;
        vector<int> x(n, 0), y(n, 0);
        std::copy(a.begin(), a.end(), x.begin());
        std::copy(b.begin(), b.end(), y.begin());
        fft(x);
        fft(y);
        // Pointwise product, pre-scaled by 1/n for the inverse transform.
        for (int i = 0; i < n; ++i)
            x[i] = (((static_cast<long long>(x[i]) * y[i]) % M) * invN) % M;
        // Inverse transform = forward transform of the index-reversed sequence x[1..n-1].
        std::reverse(x.begin() + 1, x.end());
        fft(x);
        x.resize(na + nb - 1);
        return x;
    }
};

// NTT engine for the main prime 998244353, transform lengths up to 2^21.
// The root 31^4 is validated by the Fft constructor's assert (order exactly 2^21).
Fft<998244353, 21, 31 * 31 * 31 * 31> muls;

/// Residue class modulo the compile-time constant MOD.
/// Invariant: `value` is always the canonical representative in [0, MOD).
/// Assumes 2 * MOD fits in int32_t (additions are done in int32_t) and that MOD is
/// prime (division uses Fermat's little theorem, x^(MOD-2)).
template<int32_t MOD>
struct ModInt
{
    int32_t value;

    ModInt() : value(0)
    {
    }

    /// Reduces any long long, including negative ones, into [0, MOD).
    ModInt(long long v) : value(static_cast<int32_t>(v % MOD))
    {
        if (value < 0)
            value += MOD;
    }

    /// Reduces any int32_t, including negative ones, into [0, MOD).
    ModInt(int32_t v): value(v % MOD)
    {
        if (value < 0)
            value += MOD;
    }

    ModInt &operator+=(ModInt m)
    {
        value += m.value;
        if (value >= MOD)
            value -= MOD;
        return *this;
    }

    ModInt &operator-=(ModInt m)
    {
        value -= m.value;
        if (value < 0)
            value += MOD;
        return *this;
    }

    ModInt &operator*=(ModInt m)
    {
        value = static_cast<int32_t>((value * 1LL * m.value) % MOD);
        return *this;
    }

    /// this^exp by iterative binary exponentiation. exp must be >= 0; exp == 0 yields 1.
    /// Uses only *=, so it does not depend on the free binary operators.
    ModInt power(long long exp) const
    {
        ModInt result(1);
        ModInt base = *this;
        while (exp > 0)
        {
            if (exp & 1)
                result *= base;
            base *= base;
            exp >>= 1;
        }
        return result;
    }

    /// Division via the Fermat inverse m^(MOD-2); dividing by 0 yields 0.
    ModInt &operator/=(ModInt m) { return *this *= m.power(MOD - 2); }

    /// Reads an integer and stores it reduced into [0, MOD), so the invariant holds
    /// for any input, including negative or >= MOD values. On a failed read m is unchanged.
    friend std::istream &operator>>(std::istream &is, ModInt &m)
    {
        long long raw = 0;
        if (is >> raw)
            m = ModInt(raw);
        return is;
    }

    friend std::ostream &operator<<(std::ostream &os, const ModInt &m)
    {
        os << m.value;
        return os;
    }

    explicit operator int32_t() const { return value; }

    explicit operator long long() const { return value; }

    static int32_t mod() { return MOD; }
};

// Binary +, -, *, / for ModInt in three forms each: (mint, mint), (scalar, mint)
// and (mint, scalar). A scalar operand is converted with ModInt's constructors and
// the matching compound-assignment operator does the arithmetic.

template<int32_t MOD>
ModInt<MOD> operator+(ModInt<MOD> lhs, ModInt<MOD> rhs)
{
    lhs += rhs;
    return lhs;
}

template<int32_t MOD, typename L>
ModInt<MOD> operator+(L raw, ModInt<MOD> rhs)
{
    ModInt<MOD> lhs(raw);
    lhs += rhs;
    return lhs;
}

template<int32_t MOD, typename R>
ModInt<MOD> operator+(ModInt<MOD> lhs, R raw)
{
    lhs += raw;
    return lhs;
}

template<int32_t MOD>
ModInt<MOD> operator-(ModInt<MOD> lhs, ModInt<MOD> rhs)
{
    lhs -= rhs;
    return lhs;
}

template<int32_t MOD, typename L>
ModInt<MOD> operator-(L raw, ModInt<MOD> rhs)
{
    ModInt<MOD> lhs(raw);
    lhs -= rhs;
    return lhs;
}

template<int32_t MOD, typename R>
ModInt<MOD> operator-(ModInt<MOD> lhs, R raw)
{
    lhs -= raw;
    return lhs;
}

template<int32_t MOD>
ModInt<MOD> operator*(ModInt<MOD> lhs, ModInt<MOD> rhs)
{
    lhs *= rhs;
    return lhs;
}

template<int32_t MOD, typename L>
ModInt<MOD> operator*(L raw, ModInt<MOD> rhs)
{
    ModInt<MOD> lhs(raw);
    lhs *= rhs;
    return lhs;
}

template<int32_t MOD, typename R>
ModInt<MOD> operator*(ModInt<MOD> lhs, R raw)
{
    lhs *= raw;
    return lhs;
}

template<int32_t MOD>
ModInt<MOD> operator/(ModInt<MOD> lhs, ModInt<MOD> rhs)
{
    lhs /= rhs;
    return lhs;
}

template<int32_t MOD, typename L>
ModInt<MOD> operator/(L raw, ModInt<MOD> rhs)
{
    ModInt<MOD> lhs(raw);
    lhs /= rhs;
    return lhs;
}

template<int32_t MOD, typename R>
ModInt<MOD> operator/(ModInt<MOD> lhs, R raw)
{
    lhs /= raw;
    return lhs;
}

template<int32_t MOD>
bool operator==(ModInt<MOD> a, ModInt<MOD> b) { return a.value == b.value; }

template<int32_t MOD, typename L>
bool operator==(L a, ModInt<MOD> b) { return a == b.value; }

template<int32_t MOD, typename R>
bool operator==(ModInt<MOD> a, R b) { return a.value == b; }

template<int32_t MOD>
bool operator!=(ModInt<MOD> a, ModInt<MOD> b) { return a.value != b.value; }

template<int32_t MOD, typename L>
bool operator!=(L a, ModInt<MOD> b) { return a != b.value; }

template<int32_t MOD, typename R>
bool operator!=(ModInt<MOD> a, R b) { return a.value != b; }

using mint = ModInt<md>;

/// Multiplicative inverse of x modulo md via Fermat's little theorem: x^(md-2).
/// x is expected to be nonzero (0 maps to 0).
mint inv(mint x)
{
    return x.power(mint::mod() - 2);
}

/// Extended Euclidean algorithm on __int128 (iterative form).
/// Returns d, the last nonzero remainder of Euclid's sequence for (a, b), and stores
/// Bezout coefficients with a*x + b*y == d. The coefficients equal those of the
/// classic recursive formulation.
__int128 gcd(__int128 a, __int128 b, __int128 &x, __int128 &y)
{
    // Invariants: prevR == a*prevX + b*prevY and curR == a*curX + b*curY.
    __int128 prevR = a, curR = b;
    __int128 prevX = 1, curX = 0;
    __int128 prevY = 0, curY = 1;
    while (curR != 0)
    {
        const __int128 q = prevR / curR;
        const __int128 nextR = prevR - q * curR;
        prevR = curR;
        curR = nextR;
        const __int128 nextX = prevX - q * curX;
        prevX = curX;
        curX = nextX;
        const __int128 nextY = prevY - q * curY;
        prevY = curY;
        curY = nextY;
    }
    x = prevX;
    y = prevY;
    return prevR;
}

/// Inverse of r modulo m, returned in [0, m). Requires gcd(r, m) == 1 (not checked).
/// The Bezout coefficient from gcd() is shifted by m once before reducing, which
/// assumes it is greater than -m.
__int128 inv(__int128 r, __int128 m)
{
    __int128 coefR = 0;
    __int128 coefM = 0;
    gcd(r, m, coefR, coefM);
    const __int128 shifted = coefR + m;
    return shifted % m;
}

/// One Chinese-remainder step: combines x == r (mod n) with x == c (mod m) for
/// coprime n, m into x = r + t*n, where t = ((c - r) * n^{-1}) mod m.
/// When 0 <= r < n the result lies in [0, n*m).
__int128 crt(__int128 r, __int128 n, __int128 c, __int128 m)
{
    const __int128 diff = (c - r) % m + m;   // non-negative, below 2m
    const __int128 t = diff * inv(n, m) % m;
    return r + t * n;
}
// Auxiliary NTT primes and engines (lengths up to 2^21) for the three-prime CRT path of
// operator*(vector<mint>, vector<mint>), which is taken only when md != 998244353.
// Each root is validated by the Fft constructor's assert (order exactly 2^21 mod its prime).
const int m2 = 167772161, m3 = 469762049;
Fft<m2, 21, 147771621> muls2;
Fft<m3, 21, 297449090> muls3;

/// Polynomial product (convolution) of a and b modulo md; returns a.size() + b.size() - 1
/// coefficients, or an empty vector when either operand is empty.
/// For md == 998244353 one NTT suffices. Otherwise the exact integer product is
/// recovered from three NTT primes (998244353, m2, m3) with CRT and reduced mod md.
vector<mint> operator*(vector<mint> a, vector<mint> b)
{ ///modulo-dependent convolution
    if (a.empty() || b.empty())
        return {};
    // Residues as plain ints for the NTT engines.
    auto toRaw = [](const vector<mint> &poly) {
        vector<int> raw;
        raw.reserve(poly.size());
        for (const mint &coef: poly)
            raw.push_back(coef.value);
        return raw;
    };
    const vector<int> lhs = toRaw(a);
    const vector<int> rhs = toRaw(b);
    const vector<int> c1 = muls.convolution(lhs, rhs);
    vector<mint> out;
    out.reserve(c1.size());
    if (md == 998244353)
    {
        for (int coef: c1)
            out.push_back(coef);
        return out;
    }
    const vector<int> c2 = muls2.convolution(lhs, rhs);
    const vector<int> c3 = muls3.convolution(lhs, rhs);
    assert(c1.size()==c2.size() && c2.size()==c3.size());
    const __int128 p1 = 998244353;
    for (size_t i = 0; i < c1.size(); ++i)
    {
        const __int128 mod12 = crt(c1[i], p1, c2[i], m2);
        const __int128 exact = crt(mod12, p1 * m2, c3[i], m3);
        out.push_back(static_cast<int>(exact % md));
    }
    return out;
}

/// Returns a basis of the null space {v : A v = 0} of the n x m matrix A.
/// A is forward-eliminated into row-echelon form (taken by value; the caller's matrix is
/// untouched). Every non-pivot column is a free variable, and for each one a basis vector
/// is built with that column set to 1 and the other free columns 0; the pivot variables
/// are then back-substituted. The result is empty when only v = 0 solves the system.
/// A must be non-empty (A[0] is read) with all rows of length m.
vector<vector<mint> > gaussbasis(vector<vector<mint> > A) ///returns basis of Av=0
{
    int n = A.size();
    int m = A[0].size();
    int bi = 0; // current pivot column candidate
    for (int i = 0; i < n; ++i)
    {
        if (bi == m)
            break;
        // Bring a row with a nonzero entry in column bi up to row i.
        for (int j = i; j < n; ++j)
        {
            if (A[j][bi] != 0)
            {
                if (j != i) { swap(A[i], A[j]); }
                break;
            }
        }
        if (A[i][bi] != 0)
        {
            // Clear column bi below the pivot. bi is not advanced here: the next row
            // sees a zero in column bi and takes the else-branch to move on.
            mint o = 1 / A[i][bi];
            for (int j = i + 1; j < n; ++j)
            {
                mint we = (A[j][bi] * o);
                for (int k = bi; k < m; ++k)
                {
                    A[j][k] -= we * A[i][k];
                }
            }
        }
        else
        {
            // Column bi has no pivot at or below row i: try the next column for the same row.
            ++bi;
            --i;
            continue;
        }
    }
    // Free columns = all columns minus the pivot column (first nonzero) of each row.
    vector<int> indices(m);
    iota(all(indices), 0);
    for (int i = n - 1; i >= 0; --i)
    {
        int bi = 0;
        while (bi < m && A[i][bi] == 0) { ++bi; }
        if (bi < m)
        {
            indices.erase(find(all(indices), bi));
        }
    }
    // One basis vector per free column, initially the unit vector for that column.
    vector<vector<mint> > v(indices.size(), vector<mint>(m, 0));
    for (int i = 0; i < indices.size(); ++i)
    {
        v[i][indices[i]] = 1;
    }
    // Back-substitution from the bottom row up: pivot variables of lower rows (further
    // right) are already known when a row is processed.
    for (int i = n - 1; i >= 0; --i)
    {
        int bi = 0;
        while (bi < m && A[i][bi] == 0) { ++bi; }
        if (bi == m)
            continue;
        for (int k = 0; k < indices.size(); ++k)
        {
            mint cur = 0;
            for (int j = bi + 1; j < m; ++j)
            {
                cur -= A[i][j] * v[k][j];
            }
            v[k][bi] = cur / A[i][bi];
        }
    }
    return v;
}

/// Solves A v = b for an n x m matrix A. Returns one solution, with every free
/// variable set to 0, or nullopt when the system is inconsistent.
/// A must be non-empty (A[0] is read) and b.size() == A.size() (asserted).
/// A and b are taken by value and modified locally.
optional<vector<mint> > gauss(vector<vector<mint> > A, vector<mint> b) ///returns v such that Av=b
{
    int n = A.size();
    assert(b.size()==n);
    int m = A[0].size();
    int bi = 0; // current pivot column candidate
    for (int i = 0; i < n; ++i)
    {
        if (bi == m)
            break;
        // Bring a row with a nonzero entry in column bi up to row i (b follows A).
        for (int j = i; j < n; ++j)
        {
            if (A[j][bi] != 0)
            {
                if (j != i)
                {
                    swap(A[i], A[j]);
                    swap(b[i], b[j]);
                }
                break;
            }
        }
        if (A[i][bi] != 0)
        {
            // Eliminate column bi below the pivot, applying the same row operation to b.
            mint o = inv(A[i][bi]);
            for (int j = i + 1; j < n; ++j)
            {
                mint we = (A[j][bi] * o);
                b[j] -= we * b[i];
                for (int k = bi; k < m; ++k)
                {
                    A[j][k] -= we * A[i][k];
                }
            }
        }
        else
        {
            // No pivot in column bi at or below row i: retry this row with the next column.
            ++bi;
            --i;
            continue;
        }
    }
    // Back-substitution; v starts as all zeros, so non-pivot (free) variables stay 0.
    vector<mint> v(m);
    for (int i = n - 1; i >= 0; --i)
    {
        int bi = 0;
        while (bi < m && A[i][bi] == 0) { ++bi; }
        if (bi == m)
        {
            // All-zero row: consistent only if its right-hand side is also zero.
            if (b[i] != 0) { return nullopt; }
            else { continue; }
        } {
            mint cur = b[i];
            for (int j = bi + 1; j < m; ++j)
            {
                cur -= A[i][j] * v[j];
            }
            v[bi] = cur * inv(A[i][bi]);
        }
    }
    return v;
}

optional<vector<vector<mint> > > findPrecursion(vector<mint> a)
{ ///finds P-recursion of a given sequence A by gauss
    for (int snd = 0; snd <= 20; ++snd)
    {
        for (int n = 1; n <= snd - 1; ++n)
        {
            vector<vector<mint> > A;
            int d = snd - n;
            int eq = ((int) (a.size())) - (n - 1);
            if (eq < n * d) { continue; }
            for (int i = n - 1; i < a.size(); ++i)
            {
                vector<mint> u;
                for (int j = 0; j < n; ++j)
                {
                    mint de = 1;
                    for (int k = 0; k < d; ++k)
                    {
                        u.app(a[i - j] * de);
                        de *= i;
                    }
                }
                A.app(u);
            }
            vector<vector<mint> > zx = gaussbasis(A);
            if (zx.empty())
                continue;
            //debug(n, d);
            vector<mint> ans = zx[0];
            vector<vector<mint> > res;
            for (int j = 0; j < n; ++j)
            {
                res.app({});
                for (int k = 0; k < d; ++k)
                {
                    res[j].app(ans[j * d + k]);
                }
            }
            return res;
        }
    }
    return nullopt;
}

optional<vector<mint> > evaluatePrecursion(vector<mint> a, vector<vector<mint> > rec, int sz)
{ ///a(0),...,a(a.size()-1) -> (by P-recursion rec) a(0),...,a(sz-1)
    int n = rec.size();
    int d = rec[0].size();
    int given = a.size();
    if (given >= sz)
    {
        a.resize(sz);
        return a;
    }
    if (a.size() < n) { return nullopt; }
    vector<mint> tore;
    for (int i = given; i < sz; ++i)
    {
        mint de = 1;
        mint s = 0;
        for (int k = 0; k < d; ++k)
        {
            s += de * rec[0][k];
            de *= i;
        }
        if (s == 0) { return nullopt; }
        tore.app(s);
    }
    vector<mint> pref(tore.size() + 1);
    pref[0] = 1;
    for (int i = 0; i < tore.size(); ++i) { pref[i + 1] = pref[i] * tore[i]; }
    mint pro = pref[tore.size()];
    mint invpro = 1 / pro;
    mint cur = invpro;
    vector<mint> invtore(tore.size());
    for (int i = tore.size() - 1; i >= 0; --i)
    {
        invtore[i] = cur * pref[i];
        cur *= tore[i];
    }
    for (int i = given; i < sz; ++i)
    {
        mint chi = 0;
        for (int j = 1; j < n; ++j)
        {
            mint de = 1;
            for (int k = 0; k < d; ++k)
            {
                chi += a[i - j] * de * rec[j][k];
                de *= i;
            }
        }
        a.app(((mint) (0)) - chi * invtore[i - given]);
    }
    return a;
}

/// Evaluates A(x) = sum_i a[i] * x^i modulo md using Horner's scheme (0 for empty a).
mint value(vector<mint> a, mint x)
{ ///A(x)
    mint acc = 0;
    for (auto it = a.rbegin(); it != a.rend(); ++it)
        acc = acc * x + *it;
    return acc;
}

/// Given samples P(0), ..., P(t) of a polynomial P of degree <= t (a.size() == t + 1,
/// assumed non-empty), returns P(0), ..., P(4t+1) by Lagrange extrapolation:
///   P(x) = x! / (x-t-1)! * sum_k [ P(k) * (-1)^(t-k) / (k! (t-k)!) ] / (x - k),  x > t.
/// The inner sum for all x is one convolution of those weights with 1/m (m >= 1).
/// Assumes 4t+1 < md so that all factorials used are invertible.
vector<mint> shiftofsamplingpoints(vector<mint> a)
{ ///P(0),...,P(t) we want to compute P(0),...,P(4t+1)
    int t = a.size() - 1;
    // Factorials and inverse factorials up to 4t+1.
    vector<mint> fact(4 * t + 2);
    fact[0] = 1;
    for (int i = 1; i < 4 * t + 2; ++i)
        fact[i] = fact[i - 1] * i;
    vector<mint> invf(4 * t + 2);
    invf[4 * t + 1] = 1 / fact[4 * t + 1];
    for (int i = 4 * t; i >= 0; --i) { invf[i] = (invf[i + 1] * (i + 1)); }
    assert(invf[0]==1);
    // invm[i] = 1/i for i >= 1; invm[0] stays 0 (never used for x > t).
    vector<mint> invm(4 * t + 2, 0);
    for (int i = 1; i < 4 * t + 2; ++i) { invm[i] = fact[i - 1] * invf[i]; }
    // Lagrange weights: P(k) / (k! (t-k)!) with sign (-1)^(t-k).
    vector<mint> values(t + 1, 0);
    for (int k = 0; k <= t; ++k)
    {
        mint o = 1;
        if ((t - k) % 2 == 1) { o = (((mint) (0)) - 1); }
        values[k] = (a[k] * invf[k] * invf[t - k] * o);
    }
    // h[x] = sum_k values[k] / (x - k) for x > t.
    vector<mint> h = invm * values;
    vector<mint> res;
    for (int i = 0; i <= t; ++i)
    {
        res.app(a[i]);
    }
    // prod_{j=0..t} (x - j) = x! / (x-t-1)!.
    for (int x = t + 1; x <= 4 * t + 1; ++x)
    {
        mint ans = fact[x];
        ans *= invf[x - t - 1];
        ans *= h[x];
        res.app(ans);
    }
    return res;
}

/// Computes a(id) for a sequence given by its first terms a and P-recursion rec
/// (sum_j rec_j(i) * a[i-j] == 0), without materializing all terms up to id.
/// Method: the one-step transition is an l x l matrix of polynomials (l = rec.size()-1),
/// scaled by rec_0 to avoid division. Products of B consecutive steps are obtained as
/// polynomials by repeated sample doubling (shiftofsamplingpoints), applied for whole
/// blocks, and the remaining steps are done one at a time. The original note claims
/// O(sqrt(id)*log(id)) cost.
/// Returns nullopt if too few terms are given or if the accumulated product of leading
/// coefficients is 0 (some rec_0(i) vanished). For rec.size() == 1 it returns 0.
optional<mint> evaluatePrecursionfast(vector<mint> a, vector<vector<mint> > rec, int id)
{ ///a(0),...,a(a.size()-1) -> (by P-recursion rec) a(id), O(sqrt(id)*log(id))
    if (id < a.size())
        return a[id];
    int n = rec.size();
    // d = maximum polynomial degree among rec (at least 1).
    int d = 1;
    for (auto &v: rec) { d = max(d, ((int) (v.size() - 1))); }
    if (a.size() < n - 1) { return nullopt; }
    if (n == 1) { return 0; }
    int l = n - 1; // order of the recurrence: terms kept in the state vector
    // Keep only the last l known terms; `shift` converts relative indices back to absolute ones.
    int shift = 0;
    while (a.size() > l)
    {
        a.erase(a.begin());
        ++shift;
        --id;
    }
    // Smallest u with 4^u > id; the block length B = 2^u exceeds sqrt(id).
    int u = 0;
    while ((1LL << u) * (1LL << u) <= id) { ++u; }
    int B = (1 << u);
    // S(k) = rec_0(k + l + shift) and A(k) = step matrix producing relative term k + l,
    // both stored as samples at k = 0..sz of polynomials of degree <= sz.
    vector<mint> S;
    vector<vector<vector<mint> > > A(l, vector<vector<mint> >(l));
    int sz = d;
    for (int k = 0; k <= d; ++k)
    {
        S.app(value(rec[0], k + l + shift));
    }
    // Rows 0..l-2 shift the state (scaled by rec_0): A[i][i+1] = rec_0, all else 0.
    for (int i = 0; i < l - 1; ++i)
    {
        for (int j = 0; j < l; ++j)
        {
            if (j == i + 1)
            {
                for (int k = 0; k <= d; ++k)
                {
                    A[i][j].app(value(rec[0], k + l + shift));
                }
            }
            else
            {
                for (int k = 0; k <= d; ++k)
                {
                    A[i][j].app(0);
                }
            }
        }
    }
    // Last row produces the new term: -sum_j rec_{l-j} * state[j].
    for (int j = 0; j < l; ++j)
    {
        for (int k = 0; k <= d; ++k)
        {
            A[l - 1][j].app(((mint) (0)) - value(rec[l - j], k + l + shift));
        }
    }
    // Each doubling extends the samples to k = 0..4sz+1, then fuses consecutive steps:
    // new(k) = old(2k+1) * old(2k). After u rounds, A(k) is the product of the B steps
    // starting at step k*B, and S(k) the product of their leading coefficients.
    for (int s = 0; s < u; ++s)
    {
        S = shiftofsamplingpoints(S);
        assert(S.size()==4*sz+2);
        for (int i = 0; i < l; ++i)
        {
            for (int j = 0; j < l; ++j)
            {
                A[i][j] = shiftofsamplingpoints(A[i][j]);
                assert(A[i][j].size()==4*sz+2);
            }
        }
        vector<vector<vector<mint> > > newA(l, vector<vector<mint> >(l, vector<mint>(2 * sz + 1, 0)));
        for (int k = 0; k <= 2 * sz; ++k)
        {
            for (int ii = 0; ii < l; ++ii)
            {
                for (int jj = 0; jj < l; ++jj)
                {
                    for (int kk = 0; kk < l; ++kk)
                    {
                        newA[ii][kk][k] += A[ii][jj][2 * k + 1] * A[jj][kk][2 * k];
                    }
                }
            }
        }
        vector<mint> newS(2 * sz + 1, 0);
        for (int k = 0; k <= 2 * sz; ++k) { newS[k] = S[2 * k] * S[2 * k + 1]; }
        sz *= 2;
        S = newS;
        A = newA;
    }
    // Number of whole blocks of B steps that fit before id.
    int k = (id - l) / B;
    assert(k>=0); ///id>=l+1 here
    // v = last l terms, all scaled by pro (product of every rec_0 value applied so far).
    mint pro = 1;
    vector<mint> v;
    for (int i = 0; i < l; ++i)
        v.app(a[i]);
    for (int i = 0; i < k; ++i)
    {
        vector<mint> newv(l, 0);
        vector<vector<mint> > M(l, vector<mint>(l, 0));
        for (int ii = 0; ii < l; ++ii)
        {
            for (int jj = 0; jj < l; ++jj)
            {
                assert(i<A[ii][jj].size());
                M[ii][jj] = A[ii][jj][i];
            }
        }
        pro *= S[i];
        for (int ii = 0; ii < l; ++ii)
        {
            for (int jj = 0; jj < l; ++jj)
            {
                newv[ii] += M[ii][jj] * v[jj];
            }
        }
        v = newv;
    }
    // Remaining steps one at a time; cur is the relative index of v.back().
    int cur = k * B + l - 1;
    assert(cur<id);
    while (cur < id)
    {
        mint newval = 0;
        for (int i = 0; i < l; ++i)
        {
            newval -= v[i] * value(rec[l - i], cur + 1 + shift);
        }
        mint de = value(rec[0], cur + 1 + shift);
        pro *= de;
        // Rescale the kept terms by the same factor that newval implicitly carries.
        for (mint &x: v) { x *= de; }
        v.erase(v.begin());
        v.app(newval);
        ++cur;
    }
    if (pro == 0)
        return nullopt;
    return v.back() / pro;
}

optional<vector<mint> > sequenceextender(vector<mint> a, int sz)
{ ///finds P-recursion, and if was found, calculates a(0),...,a(sz-1)
    if (a.size() >= sz)
    {
        a.resize(sz);
        return a;
    }
    auto uu = findPrecursion(a);
    if (!uu)
        return nullopt;
    auto rec = (*uu);
    auto ans = evaluatePrecursion(a, rec, sz);
    if (!ans)
        return nullopt;
    return (*ans);
}

optional<mint> fastgetvaluebyid(vector<mint> a, int id)
{ ///finds P-recursion, and if was found, calculates a(id) in O(sqrt(id)*log(id))
    if (a.size() > id) { return a[id]; }
    auto uu = findPrecursion(a);
    if (!uu)
        return nullopt;
    auto rec = (*uu);
    auto ans = evaluatePrecursionfast(a, rec, id);
    if (!ans)
        return nullopt;
    return (*ans);
}

optional<mint> optimalgetvaluebyid(vector<mint> a, int id)
{ ///finds P-recursion, and if was found, calculates a(id) by choosing optimal of O(id) method and O(sqrt(id)*log(id)) method
    if (a.size() > id) { return a[id]; }
    auto uu = findPrecursion(a);
    if (!uu)
        return nullopt;
    auto rec = (*uu);
    int n = rec.size();
    int d = rec[0].size();
    double C = 1;
    if (md != 998244353)
        C = 3;
    double val1 = sqrt(id) * log(id) * C * n * n * d + sqrt(id) * n * n * n * d;
    double val2 = n * 1.0 * d * 1.0 * id;
    //debug(val1, val2);
    if (val1 < val2)
    {
        //debug("fastgetvalue");
        auto ans = evaluatePrecursionfast(a, rec, id);
        if (!ans)
            return nullopt;
        return (*ans);
    }
    else
    {
        //debug("sequenceextender");
        auto ans = evaluatePrecursion(a, rec, id + 1);
        if (!ans)
            return nullopt;
        return (*ans)[id];
    }
}

vector<vector<mint> > transpose(vector<vector<mint> > a)
{ ///transposes the table A
    if (a.empty())
        return a;
    int n = a.size();
    int m = a[0].size();
    vector<vector<mint> > b(m, vector<mint>(n, 0));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < m; ++j)
            b[j][i] = a[i][j];
    return b;
}

/// Choosing how much of the table a to pass is a trade-off: too large and the many Gauss eliminations cause TLE; too small and there may not be enough terms to find the P-recursion (WA).
optional<vector<mint> > extendtableslow(vector<vector<mint> > a, vector<pair<int, int> > que)
{ /// extends table A, finding A(que[i].first,que[i].second), in a O(sum(que[i]))
    if (que.empty())
    {
        vector<mint> ret = {};
        return ret;
    }
    int ma = 0;
    for (auto [i,j]: que) { ma = max(ma, j); }
    int n = a.size();
    int m = a[0].size();
    vector<vector<mint> > ex(n);
    for (int i = 0; i < n; ++i)
    {
        auto h = sequenceextender(a[i], ma + 1);
        if (!h) { return nullopt; }
        ex[i] = (*h);
    }
    vector<mint> res;
    for (auto [i,j]: que)
    {
        vector<mint> e;
        for (int k = 0; k < n; ++k)
        {
            e.app(ex[k][j]);
        }
        auto h = sequenceextender(e, i + 1);
        if (!h) { return nullopt; }
        res.app((*h)[i]);
    }
    return res;
}

optional<vector<mint> > extendtablefast(vector<vector<mint> > a, vector<pair<int, int> > que)
{ /// extends table A, finding A(que[i].first,que[i].second)
    if (que.empty())
    {
        vector<mint> ret = {};
        return ret;
    }
    int n = a.size();
    int m = a[0].size();
    vector<vector<mint> > ex(n);
    vector<vector<mint> > rec[n];
    for (int k = 0; k < n; ++k)
    {
        auto uu = findPrecursion(a[k]);
        if (!uu) { return nullopt; }
        rec[k] = (*uu);
    }
    vector<mint> res;
    for (auto [i,j]: que)
    {
        vector<mint> e;
        for (int k = 0; k < n; ++k)
        {
            auto uu = evaluatePrecursionfast(a[k], rec[k], j);
            if (!uu) { return nullopt; }
            e.app(*uu);
        }
        auto h = optimalgetvaluebyid(e, i);
        if (!h) { return nullopt; }
        res.app(*h);
    }
    return res;
}

/// Answers the table queries with extendtablefast or extendtableslow, whichever a rough
/// cost model predicts to be cheaper (C = 3 models the three-prime convolution when
/// md != 998244353).
optional<vector<mint> > extendtable(vector<vector<mint> > a, vector<pair<int, int> > que)
{ /// extends table A, finding A(que[i].first,que[i].second)
    const double C = (md != 998244353) ? 3.0 : 1.0;
    const int rows = a.size();
    double fastCost = 0;
    double slowCost = 0;
    for (const auto &[i, j]: que)
    {
        fastCost += C * 1.0 * rows * 1.0 * sqrt(i + 1) * 1.0 * log(i + 2) * 5 * 5 * 5;
        fastCost += C * sqrt(j + 1) * log(j + 2) * 5 * 5 * 5;
        slowCost += C * 1.0 * rows * 1.0 * (i + 1) * 1.0 * 5 * 5;
        slowCost += C * 1.0 * (j + 1) * 1.0 * 5 * 5;
    }
    if (fastCost < slowCost)
        return extendtablefast(a, que);
    return extendtableslow(a, que);
}

optional<vector<mint> > getcolumnoftable(vector<vector<mint> > a, int col, int size)
{ /// get A(0,col),A(1,col),...,A(size-1,col)
    int n = a.size();
    int m = a[0].size();
    vector<mint> e;
    for (int i = 0; i < n; ++i)
    {
        auto h = sequenceextender(a[i], col + 1);
        if (!h) { return nullopt; }
        e.app((*h)[col]);
    }
    auto h = sequenceextender(e, size);
    if (!h)
        return nullopt;
    return *h;
}

/// Returns A(row,0), ..., A(row,size-1); row `row` of A is column `row` of A^T.
optional<vector<mint> > getrowoftable(vector<vector<mint> > a, int row, int size)
{ /// get A(row,0),A(row,1),...,A(row,size-1)
    const vector<vector<mint> > flipped = transpose(a);
    return getcolumnoftable(flipped, row, size);
}

/// Manual smoke test of the P-recursion and table utilities (not called by main).
/// Results are only printed when LOCAL is defined; otherwise debug/debugv expand to
/// nothing. The macros stringify their arguments, so argument spelling is part of the output.
void test()
{
    // Sequence extension from 10 known terms.
    vector<mint> v1 = {1, 1, 3, 7, 19, 51, 141, 393, 1107, 3139}; ///(1+x+x^2)^n [x^n]
    auto uu1 = sequenceextender(v1, 30);
    if (uu1)
    {
        auto u1 = (*uu1);
        debugv(u1);
    }
    vector<mint> v2 = {1, 30, 465, 4930, 40020, 264306, 1474795, 7133130, 30462615, 116470380}; ///(1+x+x^2)^30 [x^n]
    auto uu2 = sequenceextender(v2, 70);
    if (uu2)
    {
        auto u2 = (*uu2);
        debugv(u2);
    }
    // Expected to fail: per the note below, this sequence has no P-recursion.
    vector<mint> v3 = {0, 567646151, 513265721, 604121291, 715018514, 398975714, 610803800, 499563577, 491416403,
        913506524
    }; ///s(30,n), not D-finite
    auto uu3 = sequenceextender(v3, 35);
    if (uu3)
    {
        auto u3 = (*uu3);
        debugv(u3);
    }
    vector<mint> v4 = {0, 1, 2, 9, 44, 265, 1854, 14833, 133496, 1334961, 14684570};
    ///n!*x*e^(-x) [x^n] (number of permutations of size n without stable points)
    auto uu4 = sequenceextender(v4, 20);
    if (uu4)
    {
        auto u4 = (*uu4);
        debugv(u4);
    }
    // Single-term evaluation by index.
    auto uu5 = fastgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, 11); ///factorials ,by id
    if (uu5)
    {
        auto u5 = (*uu5);
        debug(u5);
    }
    auto uu6 = fastgetvaluebyid(v1, 29); ///(1+x+x^2)^n [x^n], by id
    if (uu6)
    {
        auto u6 = (*uu6);
        debug(u6);
    }
    auto uu7 = fastgetvaluebyid(v1, 500000000); ///(1+x+x^2)^n [x^n], by id
    if (uu7)
    {
        auto u7 = (*uu7);
        debug(u7);
    }
    auto uu8 = optimalgetvaluebyid(v1, 500000000); ///(1+x+x^2)^n [x^n], by id
    if (uu8)
    {
        auto u8 = (*uu8);
        debug(u8);
    }
    auto uu9 = optimalgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, 11); ///factorials ,by id
    if (uu9)
    {
        auto u9 = (*uu9);
        debug(u9);
    }
    auto uu10 = optimalgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, 998244352); ///factorials ,by id
    if (uu10)
    {
        auto u10 = (*uu10);
        debug(u10);
    }
    // Table extension: the three strategies should agree on the same queries.
    vector<vector<mint> > table1 =
            {{1, 2, 3, 4, 5}, {2, 3, 4, 5, 6}, {3, 4, 5, 6, 7}, {4, 5, 6, 7, 8}, {5, 6, 7, 8, 9}}; ///f(i,j)=i+j+1
    vector<pair<int, int> > que1 = {{1, 1}, {0, 0}, {5, 7}, {11342, 1333}};
    auto uu11 = extendtablefast(table1, que1);
    auto uu12 = extendtableslow(table1, que1);
    auto uu13 = extendtable(table1, que1);
    debug((bool) (uu11));
    debug((bool) (uu12));
    debug((bool) (uu13));
    if (uu11 && uu12 && uu13)
    {
        auto u11 = (*uu11);
        auto u12 = (*uu12);
        auto u13 = (*uu13);
        debugv(u11);
        debugv(u12);
        debugv(u13);
    }
    // Whole column / row of the table.
    auto uu14 = getcolumnoftable(table1, 10, 100);
    if (uu14)
    {
        auto u14 = (*uu14);
        debugv(u14);
    }
    auto uu15 = getrowoftable(table1, 50, 70);
    if (uu15)
    {
        auto u15 = (*uu15);
        debugv(u15);
    }
    ///exp((x+y)/((1-x)(1-y)), I removed it from examples
}

int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    int n,k;cin>>n>>k;
    k=min(k,n-k);
    if(n>=md) {n-=md;}
    if(n<k) {
        cout<<0;
        return 0;
    }
    auto uu5 = fastgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, n); ///factorials ,by id
    auto uu6 = fastgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, k); ///factorials ,by id
    auto uu7 = fastgetvaluebyid({1, 1, 2, 6, 24, 120, 720}, n-k); ///factorials ,by id
    mint ans=((*uu5)/(*uu6))/(*uu7);
    cout<<ans;
    return 0;
}
0