結果

問題 No.1100 Boxes
ユーザー 🍮かんプリン🍮かんプリン
提出日時 2021-10-26 20:15:47
言語 C++11
(gcc 11.4.0)
結果
AC  
実行時間 64 ms / 2,000 ms
コード長 23,331 bytes
コンパイル時間 2,118 ms
コンパイル使用メモリ 174,604 KB
実行使用メモリ 6,104 KB
最終ジャッジ日時 2024-10-05 22:14:06
合計ジャッジ時間 3,300 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
5,248 KB
testcase_01 AC 1 ms
5,248 KB
testcase_02 AC 1 ms
5,248 KB
testcase_03 AC 1 ms
5,248 KB
testcase_04 AC 1 ms
5,248 KB
testcase_05 AC 1 ms
5,248 KB
testcase_06 AC 1 ms
5,248 KB
testcase_07 AC 2 ms
5,248 KB
testcase_08 AC 1 ms
5,248 KB
testcase_09 AC 1 ms
5,248 KB
testcase_10 AC 1 ms
5,248 KB
testcase_11 AC 1 ms
5,248 KB
testcase_12 AC 2 ms
5,248 KB
testcase_13 AC 2 ms
5,248 KB
testcase_14 AC 1 ms
5,248 KB
testcase_15 AC 1 ms
5,248 KB
testcase_16 AC 2 ms
5,248 KB
testcase_17 AC 2 ms
5,248 KB
testcase_18 AC 2 ms
5,248 KB
testcase_19 AC 3 ms
5,248 KB
testcase_20 AC 8 ms
5,248 KB
testcase_21 AC 26 ms
5,248 KB
testcase_22 AC 50 ms
5,936 KB
testcase_23 AC 24 ms
5,248 KB
testcase_24 AC 30 ms
5,248 KB
testcase_25 AC 33 ms
5,248 KB
testcase_26 AC 61 ms
6,072 KB
testcase_27 AC 51 ms
5,916 KB
testcase_28 AC 14 ms
5,248 KB
testcase_29 AC 55 ms
5,852 KB
testcase_30 AC 48 ms
5,848 KB
testcase_31 AC 20 ms
5,248 KB
testcase_32 AC 36 ms
5,248 KB
testcase_33 AC 64 ms
6,104 KB
testcase_34 AC 63 ms
5,980 KB
testcase_35 AC 2 ms
5,248 KB
testcase_36 AC 55 ms
6,104 KB
testcase_37 AC 1 ms
5,248 KB
testcase_38 AC 29 ms
5,248 KB
testcase_39 AC 60 ms
6,080 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

/**
 *   @FileName	a.cpp
 *   @Author	kanpurin
 *   @Created	2021.10.26 20:15:40
**/

#include "bits/stdc++.h" 
using namespace std; 
typedef long long ll;


/**
 * Integer modulo MOD with the usual arithmetic operators.
 *
 * @tparam MOD  modulus.  operator+= relies on x + a.x fitting in unsigned int,
 *              so 2 * (MOD - 1) must be < 2^32.  inv() uses Fermat's little
 *              theorem and sqrt() uses Euler's criterion / Tonelli-Shanks, so
 *              both additionally require MOD to be an odd prime.
 *
 * Every public operation keeps `x` reduced into [0, MOD).
 */
template< int MOD >
struct mint {
public:
    unsigned int x;
    mint() : x(0) {}
    // Implicit on purpose: lets integer literals mix freely with mint values.
    mint(long long v) {
        long long w = (long long)(v % (long long)(MOD));
        if (w < 0) w += MOD;
        x = (unsigned int)(w);
    }
    /**
     * Parse a decimal string, optionally prefixed by '+' or '-', and reduce it
     * modulo MOD.  Digit strings of any length are accepted; non-digit
     * characters are not validated.
     */
    mint(const std::string &s) {
        std::size_t i = 0;
        bool negative = false;
        if (!s.empty() && (s[0] == '-' || s[0] == '+')) {
            negative = (s[0] == '-');
            i = 1;
        }
        // 64-bit accumulator: z < MOD < 2^31, so z * 10 + 9 cannot overflow.
        // (A 32-bit accumulator overflows as soon as z exceeds ~4.29e8.)
        unsigned long long z = 0;
        for (; i < s.size(); i++) {
            z = (z * 10 + (unsigned long long)(s[i] - '0')) % MOD;
        }
        x = (unsigned int)z;
        if (negative && x != 0) x = MOD - x;
    }
    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }
    mint& operator+=(const mint &a) {
        if ((x += a.x) >= MOD) x -= MOD;
        return *this;
    }
    mint& operator-=(const mint &a) {
        // Unsigned wrap-around: a negative difference becomes >= MOD.
        if ((x -= a.x) >= MOD) x += MOD;
        return *this;
    }
    mint& operator*=(const mint &a) {
        unsigned long long z = x;
        z *= a.x;
        x = (unsigned int)(z % MOD);
        return *this;
    }
    mint& operator/=(const mint &a) {return *this = *this * a.inv(); }
    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs.x == rhs.x;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs.x != rhs.x;
    }
    friend std::ostream& operator<<(std::ostream &os, const mint &n) {
        return os << n.x;
    }
    friend std::istream &operator>>(std::istream &is, mint &n) {
        // Read as signed 64-bit so negative inputs map to their residue
        // (reading into unsigned would wrap "-3" to UINT_MAX - 2).
        long long v = 0;
        is >> v;
        n = mint(v);
        return is;
    }
    /// Multiplicative inverse via x^(MOD-2); asserts x != 0.
    mint inv() const {
        assert(x);
        return pow(MOD-2);
    }
    /// Binary exponentiation; n must be non-negative.
    mint pow(long long n) const {        
        assert(0 <= n);
        mint p = *this, r = 1;
        while (n) {
            if (n & 1) r *= p;
            p *= p;
            n >>= 1;
        }
        return r;
    }
    
    /**
     * A square root of *this (Tonelli-Shanks).  Returns 0 when *this is a
     * quadratic non-residue (checked with Euler's criterion).
     */
    mint sqrt() const {
        if (this->x < 2) return *this;
        if (this->pow((MOD-1)>>1).x != 1) return mint(0);
        // Find a quadratic non-residue b.
        mint b = 1, one = 1;
        while (b.pow((MOD-1) >> 1) == 1) b += one;
        // Write MOD - 1 = m * 2^e with m odd.
        long long m = MOD-1, e = 0;
        while (m % 2 == 0) m >>= 1, e += 1;
        mint x = this->pow((m - 1) >> 1);
        mint y = (*this) * x * x;
        x *= (*this);
        mint z = b.pow(m);
        while (y.x != 1) {
            int j = 0;
            mint t = y;
            while (t != one) j += 1, t *= t;
            z = z.pow(1LL << (e-j-1));
            x *= z; z *= z; y *= z; e = j;
        }
        return x;
    }
};

// Modulus used by main(): 998244353 = 119 * 2^23 + 1.
// Inside the mint / FormalPowerSeries templates the template parameter that is
// also named MOD shadows this global.
constexpr int MOD = 998244353;

/**
 * Formal power series (equivalently, polynomial) with coefficients in
 * mint<MOD>.  The coefficient of x^i is stored in a[i].
 *
 * @tparam MOD  coefficient modulus.
 * @tparam any  false: products use one NTT modulo MOD, so MOD must be a prime
 *              supporting the power-of-two NTT lengths used (e.g. 998244353).
 *              true: products are computed modulo three fixed NTT primes and
 *              recombined by CRT into mint<MOD>, so MOD need not be
 *              NTT-friendly.  inverse/log/exp/sqrt still call mint::inv(),
 *              which assumes MOD is prime.
 *
 * Methods taking `int deg = -1` use the current size() when deg == -1.
 */
template < const int MOD , bool any = false>
struct FormalPowerSeries {
private:
    using FPS = FormalPowerSeries<MOD,any>;

    /**
     * In-place NTT of length size(), which must be a power of two.
     * ntt(false) is the forward transform; ntt(true) is its inverse including
     * the 1/n scaling, so a forward transform followed by ntt(true) restores
     * the input.  No bit-reversal permutation is done: the forward output is
     * in the transform's internal order and is only meant to be multiplied
     * pointwise and fed back to ntt(true).
     */
    void ntt(bool inverse) {
        // Twiddle tables are built once per instantiation from a quadratic
        // non-residue `root` (searched upward from 2).
        static bool first = true;
        static mint<MOD> dw[30], idw[30];
        if (first) {
            first = false;
            mint<MOD> root = 2;
            while (root.pow((MOD - 1) / 2) == 1) root += 1;
            for (int i = 0; i < 30; i++) dw[i] = -root.pow((MOD - 1) >> (i + 2)), idw[i] = mint<MOD>(1) / dw[i];
        }
        int n = this->size();
        assert((n & (n - 1)) == 0);
        if (not inverse) {
            for (int m = n; m >>= 1;) {
                mint<MOD> w = 1;
                for (int s = 0, k = 0; s < n; s += 2 * m) {
                    for (int i = s, j = s + m; i < s + m; i++, j++) {
                        auto x = this->a[i], y = this->a[j]*w;
                        // Lazy reduction: between stages entries may lie in
                        // [0, 2*MOD).  x is reduced here, y already is (via *=),
                        // and the final `e *= c` loop reduces every entry.
                        if (x.x >= MOD) x.x -= MOD;
                        this->a[i].x = x.x + y.x, this->a[j].x = x.x+(MOD-y.x);
                    }
                    w *= dw[__builtin_ctz(++k)];
                }
            }
        } else {
            for (int m = 1; m < n; m *= 2) {
                mint<MOD> w = 1;
                for (int s = 0, k = 0; s < n; s += 2 * m) {
                    for (int i = s, j = s + m; i < s + m; i++, j++) {
                        auto x = this->a[i], y = this->a[j];
                        this->a[i] = x+y, this->a[j].x = x.x+(MOD-y.x), this->a[j] *= w;
                    }
                    w *= idw[__builtin_ctz(++k)];
                }
            }
        }
        // c is 1/n for the inverse transform and 1 for the forward one; the
        // multiplication also finishes the lazy reduction above.
        auto c = mint<MOD>(1) / mint<MOD>(inverse ? n : 1);
        for (auto&& e : this->a) e *= c;
    }
    
    /// Schoolbook product a * b (size n + m - 1); both must be non-empty.
    FPS convolution_naive(FPS &a, FPS &b) const {
        int n = int(a.size()), m = int(b.size());
        FPS ans(n+m-1);
        if (n < m) {
            for (int j = 0; j < m; j++) {
                for (int i = 0; i < n; i++) ans[i + j] += a[i]*b[j];
            }
        } 
        else {
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) ans[i + j] += a[i]*b[j];
            }
        }
        return ans;
    }

    /**
     * *this = *this * b.  Empty operands give an empty result.  With
     * any == false: schoolbook when min(n, m) <= 60, otherwise one NTT.
     * With any == true: three NTT products (mod0/mod1/mod2) combined by crt().
     */
    FPS& convolution_inplace(FPS b) {
        if (this->size() == 0 || b.size() == 0) {
            this->a.clear();
            return *this;
        }
        if (!any) {
            int n = this->size(), m = b.size(), sz = 1 << __lg(2*(n+m-1)-1);
            if (min(n, m) <= 60) return *this = convolution_naive(*this,b);
            this->resize(sz), this->ntt(false);
            b.resize(sz), b.ntt(false);
            for (int i = 0; i < sz; i++) this->a[i] *= b[i];
            this->ntt(true), this->resize(n + m - 1);
            return *this;
        }
        else {
            int n = this->a.size(), m = b.a.size();
            static constexpr int mod0 = 998244353, mod1 = 1300234241, mod2 = 1484783617;
            FormalPowerSeries< mod0 > l0(n), r0(m);
            FormalPowerSeries< mod1 > l1(n), r1(m);
            FormalPowerSeries< mod2 > l2(n), r2(m);
            for (int i = 0; i < n; i++) l0.a[i] = this->a[i].x, l1.a[i] = this->a[i].x, l2.a[i] = this->a[i].x;
            for (int j = 0; j < m; j++) r0.a[j] = b.a[j].x, r1.a[j] = b.a[j].x, r2.a[j] = b.a[j].x;
            l0 *= r0;
            l1 *= r1;
            l2 *= r2;
            crt(*this,l0,l1,l2);
            return *this;
        }
    }

    /**
     * Garner-style CRT: for each i, reconstructs the integer whose residues
     * are fps0[i] (mod MOD0), fps1[i] (mod MOD1), fps2[i] (mod MOD2) as
     * y0 + MOD0*y1 + MOD0*MOD1*y2 and stores it modulo MOD in fps[i].
     */
    template<const int MOD0, const int MOD1, const int MOD2>
    static void crt(FPS &fps,
             const FormalPowerSeries<MOD0> &fps0, 
             const FormalPowerSeries<MOD1> &fps1, 
             const FormalPowerSeries<MOD2> &fps2) {
        assert(fps0.size() == fps1.size() && fps0.size() == fps2.size());
        int n = (int)fps0.size();
        fps.resize(n);
        static const mint<MOD1> im0 = mint<MOD1>(MOD0).inv();
        static const mint<MOD2> im1 = mint<MOD2>(MOD1).inv(), im0m1 = im1/MOD0;
        static const mint<MOD> m0 = MOD0, m0m1 = m0*MOD1;
        for (int i = 0; i < n; i++) {
            int y0 = fps0.a[i].x;
            int y1 = (im0*(fps1.a[i]-y0)).x;
            int y2 = (im0m1*(fps2.a[i]-y0)-im1*y1).x;
            fps.a[i] = m0m1*y2+y0+m0*y1;
        }
    }

    /// Factorials FACT[0..n] and their inverses IFACT[0..n] modulo MOD.
    struct Fact {
    private:
        int N;
    public:
        vector< mint< MOD > > FACT, IFACT;
        Fact(int n) : N(n) {
            FACT.resize(n + 1);
            IFACT.resize(n + 1);
            FACT[0] = 1;
            for (int i = 1; i <= n; i++) {
                FACT[i] = FACT[i - 1] * i;
            }
            IFACT[n] = FACT[n].inv();
            for (int i = n-1; i >= 0; i--) {
                IFACT[i] = IFACT[i+1] * (i+1);
            }
        }
    };

    /// Copy with the coefficient order reversed.
    FPS rev() const {
        FPS ret(*this);
        reverse(ret.a.begin(), ret.a.end());
        return ret;
    }

    /// Drop trailing zero coefficients.
    void shrink() {
        while (this->a.size() && this->a.back() == 0) this->a.pop_back();
    }

    /**
     * Product tree for xs: leaves g[k + i] = x - xs[i] (k = next power of two
     * >= xs.size(), unused leaves = 1), internal node g[i] = g[2i] * g[2i+1].
     */
    static vector<FPS> subproduct_tree(const vector<mint<MOD>> &xs) {
        int n = (int) xs.size();
        int k = 1;
        while(k < n) k <<= 1;
        vector<FPS> g(2 * k, {1});
        for(int i = 0; i < n; i++) g[k + i] = {-xs[i], 1};
        for(int i = k; i-- > 1;) g[i] = g[i << 1] * g[i << 1 | 1];
        return g;
    }

    /**
     * NTT-based square root of *this (requires a[0] == 1) to s terms.
     * Newton iteration doubling n each round: f approximates sqrt(*this) and
     * g is maintained as an approximation of 1/f.
     */
    FPS _sqrt(int s) const {
        assert(this->a[0]==1);
        static const mint<MOD> half=mint<MOD>(1)/2;
        FPS f({1}),g({1}),z({1});
        for(int n=1;n<s;n*=2){
            // z holds the forward NTT of f; squaring pointwise gives f^2.
            for (int i = 0; i < n; i++) z[i]*=z[i];
            z.ntt(true);

            FPS delta(2*n),gbuf(2*n);
            for (int i = 0; i<n; i++) delta[n+i] = z[i] - (i<size()?this->a[i]:0) - (n+i<size()?this->a[n+i]:0);
            copy(g.a.begin(),g.a.end(), gbuf.a.begin());
            delta.ntt(false);
            gbuf.ntt(false);
            for (int i = 0;i < 2*n; i++) delta[i]*=gbuf[i];
            delta.ntt(true);

            f.resize(2*n);
            for(int i=n;i<2*n;i++) f[i]=-half*delta[i];

            if(2*n>=s)break;

            z=f;
            z.ntt(false);

            // Refine g towards 1/f using the high half of g * f.
            FPS eps=gbuf;
            for (int i = 0;i < 2*n;i++) eps[i]*=z[i];
            eps.ntt(true);

            for(int i = 0; i < n; i++)eps[i]=0;
            eps.ntt(false);

            for(int i = 0; i < 2*n; i++)eps[i]*=gbuf[i];
            eps.ntt(true);
            g.resize(2*n);
            for(int i = n; i < 2*n; i++)g[i]=-eps[i];
        }
        f.resize(s);
        return f;
    }

public:
    /// Coefficients; a[i] is the coefficient of x^i.
    vector<mint<MOD>> a;

    /// Zero series with sz coefficients.
    FormalPowerSeries(int sz = 0) {
        this->a.resize(sz, 0);
    }

    FormalPowerSeries(const std::initializer_list<mint<MOD>> v) {
        this->a = v;
    }

    FormalPowerSeries(const std::vector<mint<MOD>> &v) {
        this->a = v;
    }

    int size() const {
        return a.size();
    }

    void resize(size_t sz, mint<MOD> m = mint<MOD>(0)) {
        this->a.resize(sz,m);
    }

    FPS operator+(const mint<MOD> &a) const { return FPS(*this) += a; }
    FPS operator+(const FPS &a) const { return FPS(*this) += a; }
    FPS operator-(const mint<MOD> &a) const { return FPS(*this) -= a; }
    FPS operator-(const FPS &a) const { return FPS(*this) -= a; }
    FPS operator*(const mint<MOD> &a) const { return FPS(*this) *= a; }
    FPS operator*(const long long a) const { return FPS(*this) *= a; }
    FPS operator*(const FPS &a) const { return FPS(*this) *= a; }
    FPS operator/(const mint<MOD> &a) const { return FPS(*this) /= a;}
    FPS operator/(const FPS &a) const { return FPS(*this) /= a; }
    FPS operator%(const FPS &a) const { return FPS(*this) %= a; }
    /// Adds v to the constant term (requires size() > 0).
    FPS &operator+=(const mint<MOD> &v) {
        this->a[0] += v;
        return *this;
    }
    FPS &operator+=(const FPS &r) {
        this->resize(max((int)this->size(),r.size()));
        for(int i = 0; i < (int)r.size(); i++) this->a[i] += r.a[i];
        return *this;
    }
    /// Subtracts v from the constant term (requires size() > 0).
    FPS &operator-=(const mint<MOD> &v) {
        this->a[0] -= v;
        return *this;
    }
    FPS &operator-=(const FPS &r) {
        this->resize(max((int)this->size(),r.size()));
        for(int i = 0; i < (int)r.size(); i++) this->a[i] -= r.a[i];
        return *this;
    }
    FPS &operator*=(const mint<MOD> &v) {
        for (int i = 0; i < this->size(); i++) this->a[i] *= v;
        return *this;
    }
    FPS &operator*=(const long long v) {
        for (int i = 0; i < this->size(); i++) this->a[i] *= v;
        return *this;
    }
    FPS &operator*=(const FPS &r) {
        this->convolution_inplace(r);
        return *this;
    }
    FPS &operator/=(const mint<MOD> &v){
        return *this *= v.inv();
    }
    /**
     * Polynomial quotient *this / r, resized to size() - r.size() + 1
     * coefficients (empty when size() < r.size()).  Schoolbook division for
     * r.size() <= 64, otherwise via reversed-series inversion.
     */
    FPS &operator/=(const FPS &r) {
        if (this->size() < r.size()) {
            this->a.clear();
            return *this;
        }
        int n = this->size() - r.size() + 1;
        if ((int)r.size() <= 64) {
            FPS f(*this), g(r);
            g.shrink();
            // Dividing by the zero polynomial is undefined; fail loudly instead
            // of reading g.a.back() on an empty vector.
            assert(g.size() > 0);
            mint<MOD> coeff = g.a.back().inv();
            for (auto &x : g.a) x *= coeff;
            int deg = (int)f.size() - (int)g.size() + 1;
            int gs = g.size();
            FPS quo(deg);
            for (int i = deg - 1; i >= 0; i--) {
                quo[i] = f[i + gs - 1];
                for (int j = 0; j < gs; j++) f[i + j] -= quo[i] * g[j];
            }
            *this = quo * coeff;
            this->resize(n, 0);
            return *this;
        }
        return *this = ((*this).rev().low(n) * r.rev().inverse(n)).low(n).rev();
    }

    /**
     * Polynomial remainder *this mod Q, with trailing zeros removed.
     * Unchanged when Q.size() > size().
     */
    FPS &operator%=(const FPS &Q) {
        if(Q.size() > this->size()) return *this;
        if(Q.size() < 32) {
            int dQ = Q.size()-1;
            while(dQ && Q.a[dQ] == 0) dQ--;
            assert(Q.a[dQ] != 0);
            for(int i = this->size()-1; i >= dQ; i--){
                if(this->a[i] == 0) continue;
                mint<MOD> x = this->a[i] / Q.a[dQ];
                this->a[i] = 0;
                for(int j = 1; j <= dQ; j++){
                    this->a[i - j] -= x * Q.a[dQ - j];
                }
            }
            shrink();
            return *this;
        }
        FPS P = (*this) / Q;
        P *= Q;
        int dR = -1;
        for(int i = 0; i < Q.size()-1; i++){
            P.a[i] = this->a[i] - P.a[i];
            if(P.a[i] != 0) dR = i;
        }
        this->a.resize(dR + 1);
        for(int i = 0; i <= dR; i++) this->a[i] = P.a[i];
        return *this;
    }


    /// First min(max(s, 1), size()) coefficients.
    FPS low(int s) const {
        return FPS(vector<mint<MOD>>(this->a.begin(),this->a.begin()+min(max(s,1),this->size())));
    }

    /**
     * Multiplicative inverse to deg terms (default size()) by Newton
     * iteration.  Requires a non-empty series with a[0] != 0.
     */
    FPS inverse(int deg = -1) const {
        int n = this->size();
        assert(n != 0 && this->a[0].x != 0);
        if(deg == -1) deg = n;

        if (!any) {
            FPS r({this->a[0].inv()});
            for(int m=1;m<deg;m*=2) {
                FPS f(vector<mint<MOD>>(this->a.begin(), this->a.begin() + min(n, 2*m)));
                FPS g(r);
                f.resize(2*m), f.ntt(false);
                g.resize(2*m), g.ntt(false);
                for (int i = 0; i < 2*m; i++) f[i] *= g[i];
                f.ntt(true);
                f.a.erase(f.a.begin(), f.a.begin() + m);
                f.resize(2*m), f.ntt(false);
                for (int i = 0; i < 2*m; i++) f[i] *= g[i];
                f.ntt(true); 
                for (int i = 0; i < 2*m; i++) f[i] = -f[i];
                r.a.insert(r.a.end(), f.a.begin(), f.a.begin() + m);
            }
            return r.low(deg);
        }
        else {

            FPS r({this->a[0].inv()});
            for (int i = 1; i < deg; i <<= 1)
                r = (r*2 - r.square()*(*this).low(i<<1)).low(i<<1);
            return r.low(deg);
        }
    }

    /// *this = (*this)^2 as a full polynomial product (size 2n - 1).
    FPS& square_inplace() {
        if (this->size() == 0) {
            return *this;
        }
        if (!any) {
            int n = this->size(), sz = 1 << __lg(2*(n+n-1)-1);
            if (n <= 60) return *this = convolution_naive(*this,*this);
            this->resize(sz), this->ntt(false);
            for (int i = 0; i < sz; i++) this->a[i] *= this->a[i];
            this->ntt(true), this->resize(n+n-1);
            return *this;
        }
        else {
            int n = this->a.size();
            static constexpr int mod0 = 998244353, mod1 = 1300234241, mod2 = 1484783617;
            FormalPowerSeries< mod0 > f0(n);
            FormalPowerSeries< mod1 > f1(n);
            FormalPowerSeries< mod2 > f2(n);
            for (int i = 0; i < n; i++) f0.a[i] = this->a[i].x, f1.a[i] = this->a[i].x, f2.a[i] = this->a[i].x;
            f0.square_inplace();
            f1.square_inplace();
            f2.square_inplace();
            crt(*this,f0,f1,f2);
            return *this;
        }
    }
    FPS square() const { return FPS(*this).square_inplace(); }

    /// Derivative; keeps the size, setting the last coefficient to 0.
    FPS& differential_inplace() {
        const int n = (int)this->a.size();
        assert(n > 0);
        for(int i = 1; i < n; i++) this->a[i-1] = this->a[i] * i;
        this->a[n-1] = 0;
        return *this;
    }
    FPS differential() const { return FPS(*this).differential_inplace(); }

    /// Integral with zero constant term; the size grows by one.
    FPS& integral_inplace() {
        const int n = (int)this->a.size();
        assert(n > 0);
        this->a.insert(this->a.begin(),0);
        // Linear-time table of modular inverses 1..n.
        vector<mint<MOD>> inv(n+1); inv[1] = 1;
        for (int i = 2; i <= n; i++) inv[i] = -inv[MOD%i]*(MOD/i);
        for (int i = 2; i <= n; i++) this->a[i] *= inv[i];
        return *this;
    }
    // Previously called the nonexistent integ_inplace(), which failed to
    // compile as soon as integral() was used.
    FPS integral() const { return FPS(*this).integral_inplace(); }

    /**
     * In-place logarithm of a series with a[0] == 1.  *this is first
     * truncated to deg terms when deg < size().  The result has deg + 1
     * coefficients because integral_inplace() appends one.
     */
    FPS& log_inplace(int deg = -1) {
        int n = this->size();
        assert(n > 0 && this->a[0] == 1);
        if (deg == -1) deg = n;
        if (deg < n) this->resize(deg);
        // The inverse must be accurate to deg terms; the default length
        // (current size) is too short when deg > n.
        FPS f_inv = this->inverse(deg);
        this->differential_inplace();
        *this *= f_inv;
        this->resize(deg);
        this->integral_inplace();
        return *this;
    }  
    FPS log(const int deg = -1) const { return FPS(*this).log_inplace(deg); }

    /**
     * In-place exponential of a series with a[0] == 0, to deg terms
     * (default size()).
     */
    FPS& exp_inplace(int deg = -1) {
        if (!any) {
            int n = this->size();
            assert(n > 0 && (*this)[0] == 0);
            if (deg == -1) deg = n;
            assert(deg >= 0);
            FPS g({1}), g_fft;
            this->resize(deg);
            this->a[0] = 1;
            FPS h_drv = this->differential();
            for (int m = 1; m < deg; m *= 2) {
                FPS f_fft(vector<mint<MOD>>(this->a.begin(), this->a.begin() + m));
                f_fft.resize(2*m), f_fft.ntt(false);
                mint<MOD> invm = m; invm = invm.inv();

                if (m > 1) {
                    FPS _f(m);
                    for(int i = 0; i < m; i++) _f[i] = f_fft[i] * g_fft[i];
                    _f.ntt(true);
                    _f.a.erase(_f.a.begin(), _f.a.begin() + m/2);
                    _f.resize(m), _f.ntt(false);
                    for(int i = 0; i < m; i++) _f[i] *= g_fft[i];
                    _f.ntt(true);
                    _f.resize(m/2);
                    for (int i = 0; i < m/2; i++) _f[i] = -_f[i];
                    g.a.insert(g.a.end(), _f.a.begin(), _f.a.begin() + m/2);
                }

                FPS t(vector<mint<MOD>>(this->a.begin(), this->a.begin() + m)); 
                t.differential_inplace();
                {
                    FPS r(vector<mint<MOD>>(h_drv.a.begin(), h_drv.a.begin() + m-1));
                    r.resize(m); r.ntt(false);
                    for (int i = 0; i < m; i++) r.a[i] *= f_fft.a[i];
                    r.ntt(true);
                    t -= r;
                    t.a.insert(t.a.begin(), t.a.back()); t.a.pop_back();
                }

                t.resize(2*m); t.ntt(false); 
                g_fft = g; g_fft.resize(2*m); g_fft.ntt(false);
                for (int i = 0; i < 2*m; i++) t.a[i] *= g_fft.a[i];
                t.ntt(true);
                t.resize(m);

                FPS v(vector<mint<MOD>>(this->a.begin() + m, this->a.begin() + min(deg, 2*m))); v.resize(m);
                t.a.insert(t.a.begin(), m-1, 0); t.a.push_back(0);
                t.integral_inplace();
                for (int i = 0; i < m; i++) v.a[i] -= t.a[m+i];

                v.resize(2*m); v.ntt(false);
                for (int i = 0; i < 2*m; i++) v.a[i] *= f_fft.a[i];
                v.ntt(true);
                v.resize(m);

                for (int i = 0; i < min(deg-m,m); i++) this->a[m+i] = v.a[i];
            }
            return *this;
        }
        else {
            assert(this->size() == 0 || this->a[0] == 0);
            if (deg == -1) deg = (int)this->size();
            FPS r({1});
            for (int i = 1; i < deg; i <<= 1) {
                r = (r*(this->low(i << 1)+1-r.log(i << 1))).low(i << 1);
            }
            return *this = r.low(deg);
        }
    }
    FPS exp(const int deg = -1) const { return FPS(*this).exp_inplace(deg); }

    /**
     * In-place k-th power to deg terms (default size()), k >= 0.
     * Writes *this = c * x^l * (1 + ...) and computes
     * c^k * x^(l*k) * exp(k * log(1 + ...)).
     */
    FPS& pow_inplace(ll k, int deg = -1) {
        const int n = this->size();
        if (deg == -1) deg = n;
        assert(deg >= 0 && k >= 0);
        if (deg == 0) return *this = FPS(0);
        if (k == 0) {
            // x^0 == 1 by convention (also for the zero series); this used to
            // divide by zero in `deg / k`.
            FPS one(deg);
            one.a[0] = 1;
            return *this = one;
        }
        int l = 0;
        // Bounded scan: an all-zero series used to run past the end of a.
        while (l < n && this->a[l] == 0) ++l;
        // All-zero input, or lowest term x^(l*k) at or beyond x^deg:
        // every kept coefficient is zero.  k >= ceil(deg / l) <=> l*k >= deg,
        // written without the l*k overflow.
        if (l == n || (l > 0 && k >= (deg + l - 1) / l)) return *this = FPS(deg);
        const int shift = (int)(l * k);   // < deg here
        const int rem = deg - shift;      // coefficients needed before shifting
        mint<MOD> ic = this->a[l].inv();
        mint<MOD> pc = this->a[l].pow(k);
        this->a.erase(this->a.begin(), this->a.begin() + l);
        // Work with exactly rem coefficients (zero-padded if the input is
        // shorter), so log/exp are accurate for every coefficient we keep even
        // when deg > size().
        this->resize(rem);
        *this *= ic;
        this->log_inplace(rem);
        *this *= k;
        this->exp_inplace(rem);
        *this *= pc;
        this->a.insert(this->a.begin(), shift, mint<MOD>(0));
        this->resize(deg);
        return *this;
    }
    FPS pow(const ll k, const int deg = -1) const { return FPS(*this).pow_inplace(k, deg); }

    /**
     * In-place square root to deg terms (default size()).  Sets *this to an
     * empty series when no square root exists (odd lowest degree, or lowest
     * coefficient is a non-residue).
     */
    FPS& sqrt_inplace(int deg = -1) {
        if (deg == -1) deg = this->size();
        int n = this->size(), z = 0;
        for(;z<n&&this->a[z]==0;z++);
        if(z==n) {this->resize(deg); return *this;}
        if(z%2) return *this = {};
        mint<MOD> w = this->a[z].sqrt();
        if(w*w!=this->a[z]) return *this = {};
        int s=deg-z/2;
        // The root starts at x^(z/2) >= x^deg, so all deg kept coefficients are
        // zero (a negative s previously reached resize() with a huge size).
        if (s <= 0) return *this = FPS(deg);
        mint<MOD> az = this->a[z];
        this->a.erase(this->a.begin(),this->a.begin()+z);
        *this /= az;
        if (!any) *this = this->_sqrt(s);
        else {
            FPS g({1});
            mint<MOD> two_inv = mint<MOD>(2).inv();
            for (int i = 1; i < s; i*=2) {
                g.resize(i*2);
                g += (*this).low(i*2)*g.inverse();
                g *= two_inv;
            }
            *this = g.low(s);
        }
        *this *= w;
        this->a.insert(this->a.begin(),z/2,0);
        return *this;
    }
    FPS sqrt(int deg = -1) const { return FPS(*this).sqrt_inplace(deg); }

    /// In-place Taylor shift: *this = f(x + c), keeping size().
    FPS& shift_inplace(const mint<MOD> &c) {
        int n = this->size();
        Fact fc(n);
        for (int i = 0; i < n; i++) this->a[i] *= fc.FACT[i];
        reverse(this->a.begin(), this->a.end());
        FPS g(n);
        mint<MOD> cp = 1; 
        for (int i = 0; i < n; i++) g[i] = cp * fc.IFACT[i], cp *= c;
        this->convolution_inplace(g);
        this->a.resize(n);
        reverse(this->a.begin(), this->a.end());
        for (int i = 0; i < n; i++) this->a[i] *= fc.IFACT[i];
        return *this;
    }
    FPS shift(const mint<MOD> &c) const { return FPS(*this).shift_inplace(c); }

    /// Values of the polynomial at every point of xs (remainder tree).
    vector<mint<MOD>> multipoint_evaluation(const vector<mint<MOD>> &xs) {
        auto g = subproduct_tree(xs);
        int m = (int) xs.size(), k = (int) g.size() / 2;
        g[1] = (*this) % g[1];
        for(int i = 2; i < k + m; i++) g[i] = g[i >> 1] % g[i];
        vector<mint<MOD>> ys(m);
        for(int i = 0; i < m; i++) {
            ys[i] = (g[k + i].size() == 0 ? mint<MOD>(0) : g[k + i][0]);
        }
        return ys;
    }
    vector<mint<MOD>> multipoint_evaluation(const FPS &xs) {
        return multipoint_evaluation(xs.a);
    }

    /// Bounds-checked (assert) coefficient access.
    mint<MOD> &operator[](int x) {
        assert(0 <= x && x < (int)this->a.size());
        return a[x];
    }
    
    friend std::ostream &operator<<(std::ostream &os, const FPS &p) {
        os << "[ ";
        for (int i = 0; i < p.size(); ++i) {
            os << p.a[i] << " ";
        }
        os << "]";
        return os;
    }
};

int main() {
    int n,k;cin >> n >> k;
    FormalPowerSeries<MOD> s(k),t(k);
    mint<MOD> fact = 1;
    for (int i = 1; i < k; i++) {
        fact *= i;
        s[i] = mint<MOD>(i).pow(n)/fact;
    }
    fact = 1;
    t[0] = 1;
    for (int i = 1; i < k; i++) {
        fact *= i;
        if (i & 1) {
            t[i] = mint<MOD>(-1)/fact;
        }
        else {
            t[i] = fact.inv();
        }
    }
    mint<MOD> ans = 0;
    s *= t;
    fact = 1;
    for (int m = 1; m <= k; m+=2) {
        if (m!=1) fact *= (m-1);
        fact *= m;
        ans += s[k-m]*fact.inv();
    }
    for (int i = 2; i <= k; i++) {
        ans *= i;
    }
    cout << ans << endl;
    return 0;
}
0