結果

問題 No.1321 塗るめた
ユーザー nononnonon
提出日時 2024-05-01 19:48:04
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 194 ms / 2,000 ms
コード長 16,020 bytes
コンパイル時間 3,336 ms
コンパイル使用メモリ 231,516 KB
実行使用メモリ 20,576 KB
最終ジャッジ日時 2024-05-01 19:48:12
合計ジャッジ時間 8,230 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
6,812 KB
testcase_01 AC 6 ms
6,816 KB
testcase_02 AC 2 ms
6,940 KB
testcase_03 AC 2 ms
6,940 KB
testcase_04 AC 2 ms
6,940 KB
testcase_05 AC 2 ms
6,940 KB
testcase_06 AC 2 ms
6,940 KB
testcase_07 AC 2 ms
6,940 KB
testcase_08 AC 2 ms
6,944 KB
testcase_09 AC 1 ms
6,940 KB
testcase_10 AC 1 ms
6,940 KB
testcase_11 AC 2 ms
6,944 KB
testcase_12 AC 77 ms
11,060 KB
testcase_13 AC 22 ms
6,940 KB
testcase_14 AC 39 ms
7,428 KB
testcase_15 AC 80 ms
11,536 KB
testcase_16 AC 84 ms
12,296 KB
testcase_17 AC 4 ms
6,940 KB
testcase_18 AC 168 ms
20,444 KB
testcase_19 AC 42 ms
7,424 KB
testcase_20 AC 75 ms
11,152 KB
testcase_21 AC 84 ms
11,800 KB
testcase_22 AC 164 ms
20,148 KB
testcase_23 AC 177 ms
19,760 KB
testcase_24 AC 159 ms
17,652 KB
testcase_25 AC 169 ms
20,576 KB
testcase_26 AC 167 ms
20,392 KB
testcase_27 AC 153 ms
17,328 KB
testcase_28 AC 158 ms
18,692 KB
testcase_29 AC 157 ms
17,128 KB
testcase_30 AC 194 ms
20,176 KB
testcase_31 AC 82 ms
11,672 KB
testcase_32 AC 74 ms
10,384 KB
testcase_33 AC 73 ms
10,252 KB
testcase_34 AC 84 ms
11,624 KB
testcase_35 AC 75 ms
10,548 KB
testcase_36 AC 5 ms
6,940 KB
testcase_37 AC 85 ms
11,684 KB
testcase_38 AC 84 ms
12,200 KB
testcase_39 AC 83 ms
12,176 KB
testcase_40 AC 82 ms
12,196 KB
testcase_41 AC 83 ms
11,556 KB
testcase_42 AC 2 ms
6,944 KB
testcase_43 AC 74 ms
10,380 KB
testcase_44 AC 83 ms
10,380 KB
testcase_45 AC 84 ms
11,500 KB
testcase_46 AC 21 ms
6,940 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<bits/stdc++.h>
using namespace std;
template<long long mod_>
struct modint
{
    modint():value(0){}
    modint(long long v)
    {
        long long x=(long long)(v%m());
        if(x<0)x+=m();
        value=x;
    }
    static modint raw(long long v)
    {
        modint x;
        x.value=v;
        return x;
    }
    static constexpr long long mod()noexcept{return m();}
    long long val()const{return value;}
    modint& operator++()
    {
        value++;
        if(value==m())value=0;
        return *this;
    }
    modint& operator--()
    {
        if(value==0)value=m();
        value--;
        return *this;
    }
    modint operator++(int)
    {
        modint res=*this;
        ++*this;
        return res;
    }
    modint operator--(int)
    {
        modint res=*this;
        --*this;
        return res;
    }
    modint& operator+=(const modint& a)
    {
        value+=a.value;
        if(value>=m())value-=m();
        return *this;
    }
    modint& operator-=(const modint& a)
    {
        value-=a.value;
        if(value<0)value+=m();
        return *this;
    }
    modint& operator*=(const modint& a)
    {
        unsigned long long x=value;
        x*=a.value;
        x%=m();
        if(x<0)x+=m();
        value=x;
        return *this;
    }
    modint& operator/=(const modint& a)
    {
        return *this=(*this)*a.inv();
    }
    modint operator+()const{return *this;}
    modint operator-()const{return modint()-*this;}
    modint pow(long long n)const
    {
        modint x=*this,res=1;
        while(n)
        {
            if(n&1)res*=x;
            x*=x;
            n>>=1;
        }
        return res;
    }
    modint inv()const
    {
        long long a=value,b=m(),u=1,v=0;
        while(b)
        {
            long long t=a/b;
            a-=t*b;
            swap(a,b);
            u-=t*v;
            swap(u,v);
        }
        return modint(u);
    }
    friend modint operator+(const modint& a, const modint& b)
    {
        modint res=a;
        res+=b;
        return res;
    }
    friend modint operator-(const modint& a, const modint& b)
    {
        modint res=a;
        res-=b;
        return res;
    }
    friend modint operator*(const modint& a, const modint& b)
    {
        modint res=a;
        res*=b;
        return res;
    }
    friend modint operator/(const modint& a, const modint& b)
    {
        modint res=a;
        res/=b;
        return res;
    }
    friend bool operator==(const modint& a, const modint& b)
    {
        return a.value==b.value;
    }
    friend bool operator!=(const modint& a, const modint& b)
    {
        return a.value!=b.value;
    }
private:
    long long value;
    static constexpr long long m(){return mod_;}
};
template<typename mint>
struct Number_Theoretic_Transform
{
    static vector<mint>dw,dw_inv;
    static int log;
    static mint root;
    static void ntt(vector<mint>& f)
    {
        init();
        const int n=f.size();
        for(int m=n;m>>=1;)
        {
            mint w=1;
            for(int s=0,k=0;s<n;s+=(m<<1))
            {
                for(int i=s,j=s+m;i<s+m;i++,j++)
                {
                    mint x=f[i],y=f[j]*w;
                    f[i]=x+y,f[j]=x-y;
                }
                w*=dw[__builtin_ctz(++k)];
            }
        }
    }
    static void intt(vector<mint>& f, bool flag=true)
    {
        init();
        const int n=f.size();
        for(int m=1;m<n;m<<=1)
        {
            mint w=1;
            for(int s=0,k=0;s<n;s+=(m<<1))
            {
                for(int i=s,j=s+m;i<s+m;i++,j++)
                {
                    mint x=f[i],y=f[j];
                    f[i]=x+y,f[j]=(x-y)*w;
                }
                w*=dw_inv[__builtin_ctz(++k)];
            }
        }
        if(flag)
        {
            mint cef=mint(n).inv();
            for(int i=0;i<n;i++)f[i]*=cef;
        }
    }
private:
    Number_Theoretic_Transform()=default;
    static void init()
    {
        if(!dw.empty())return;
        long long mod=998244353;
        long long tmp=mod-1;
        log=1;
        while(tmp%2==0)
        {
            tmp>>=1;
            log++;
        }
        dw.resize(log);
        dw_inv.resize(log);
        for(int i=0;i<log;i++)
        {
            dw[i]=-root.pow((mod-1)>>(i+2));
            dw_inv[i]=dw[i].inv();
        }
    }
};
template<typename mint>
vector<mint>Number_Theoretic_Transform<mint>::dw=vector<mint>();
template<typename mint>
vector<mint>Number_Theoretic_Transform<mint>::dw_inv=vector<mint>();
template<typename mint>
int Number_Theoretic_Transform<mint>::log=0;
template<typename mint>
mint Number_Theoretic_Transform<mint>::root=mint(3);
template<typename mint>
struct Formal_Power_Series:vector<mint>
{
    using FPS=Formal_Power_Series;
    using vector<mint>::vector;
    using NTT=Number_Theoretic_Transform<mint>;
    void ntt(){NTT::ntt(*this);}
    void intt(bool flag=true){NTT::intt(*this,flag);}
    FPS &operator+=(const mint& r)
    {
        if(this->empty())this->resize(1);
        (*this)[0]+=r;
        return *this;
    }
    FPS &operator-=(const mint& r)
    {
        if(this->empty())this->resize(1);
        (*this)[0]-=r;
        return *this;
    }
    FPS &operator*=(const mint& r)
    {
        for(mint &x:*this)x*=r;
        return *this;
    }
    FPS &operator/=(const mint& r)
    {
        mint invr=r.inv();
        for(mint &x:*this)x*=invr;
        return *this;
    }
    FPS operator+(const mint& r)const{return FPS(*this)+=r;}
    FPS operator-(const mint& r)const{return FPS(*this)-=r;}
    FPS operator*(const mint& r)const{return FPS(*this)*=r;}
    FPS operator/(const mint& r)const{return FPS(*this)/=r;}
    FPS& operator+=(const FPS& f)
    {
        if(this->size()<f.size())this->resize(f.size());
        for(int i=0;i<(int)f.size();i++)(*this)[i]+=f[i];
        return *this;
    }
    FPS& operator-=(const FPS& f)
    {
        if(this->size()<f.size())this->resize(f.size());
        for(int i=0;i<(int)f.size();i++)(*this)[i]-=f[i];
        return *this;
    }
    FPS& operator*=(const FPS& f)
    {
        *this=convolution(*this,f);
        return *this;
    }
    FPS& operator/=(const FPS& f)
    {
        return *this*=f.inv();
    }
    FPS& operator%=(const FPS& f)
    {
        *this-=this->div(f)*f;
        this->shrink();
        return *this;
    }
    FPS operator+(const FPS& f)const{return FPS(*this)+=f;}
    FPS operator-(const FPS& f)const{return FPS(*this)-=f;}
    FPS operator*(const FPS& f)const{return FPS(*this)*=f;}
    FPS operator/(const FPS& f)const{return FPS(*this)/=f;}
    FPS operator%(const FPS& f)const{return FPS(*this)%=f;}
    FPS operator-()const
    {
        FPS res(this->size());
        for(int i=0;i<(int)this->size();i++)res[i]-=(*this)[i];
        return res;
    }
    FPS div(FPS f)
    {
        if(this->size()<f.size())return FPS{};
        int n=this->size()-f.size()+1;
        return (rev().pre(n)*f.rev().inv(n)).pre(n).rev(n);
    }
    FPS pre(int deg)const
    {
        return FPS(begin(*this),begin(*this)+min((int)this->size(),deg));
    }
    FPS rev(int deg=-1)const
    {
        FPS res(*this);
        if(deg!=-1)res.resize(deg,0);
        reverse(begin(res),end(res));
        return res;
    }
    void shrink()
    {
        while(!this->empty()&&this->back()==0)this->pop_back();
    }
    FPS dot(FPS f)const
    {
        int n=min(this->size(),f.size());
        FPS res(n);
        for(int i=0;i<n;i++)res[i]=(*this)[i]*f[i];
        return res;
    }
    FPS operator<<(int deg)const
    {
        FPS res(*this);
        res.insert(res.begin(),deg,0);
        return res;
    }
    FPS& operator<<=(int deg)
    {
        return *this=*this<<(deg);
    }
    FPS operator>>(int deg)const
    {
        if((int)this->size()<=deg)return{};
        FPS res(*this);
        res.erase(res.begin(),res.begin()+deg);
        return res;
    }
    FPS& operator>>=(int deg)
    {
        return *this=*this>>(deg);
    }
    mint operator()(const mint& r)
    {
        mint res=0,powr=1;
        for(auto x:*this)
        {
            res+=x*powr;
            powr*=r;
        }
        return res;
    }
    FPS diff()const
    {
        int n=this->size();
        FPS res(max(0,n-1));
        for(int i=1;i<n;i++)
        {
            res[i-1]=(*this)[i]*i;
        }
        return res;
    }
    FPS integral()const
    {
        int n=this->size();
        FPS res(n+1);
        res[0]=0;
        for(int i=0;i<n;i++)
        {
            res[i+1]=(*this)[i]/(i+1);
        }
        return res;
    }
    FPS inv(int deg=-1)const
    {
        assert(((*this)[0])!=(0));
        int n=this->size();
        if(deg==-1)deg=n;
        FPS res(deg);
        res[0]={(*this)[0].inv()};
        for(int d=1;d<deg;d<<=1)
        {
            FPS f(d<<1),g(d<<1);
            for(int j=0;j<min(n,2*d);j++)f[j]=(*this)[j];
            for(int j=0;j<d;j++)g[j]=res[j];
            f.ntt();
            g.ntt();
            f=f.dot(g);
            f.intt();
            for(int j=0;j<d;j++)f[j]=0;
            f.ntt();
            f=f.dot(g);
            f.intt();
            for(int j=d;j<min(2*d,deg);j++)res[j]=-f[j];
        }
        return res;
    }
    FPS exp(int deg=-1)const
    {
        assert((*this)[0]==0);
        if(deg==-1)deg=this->size();
        vector<mint>inv;
        inv.reserve(deg+1);
        inv.push_back(mint::raw(0));
        inv.push_back(mint::raw(1));
        auto inplace_integral=[&](FPS& f)->void
        {
            int n=f.size();
            long long mod=mint::mod();
            while(inv.size()<=f.size())
            {
                int i=inv.size();
                inv.push_back((-inv[mod%i])*(mod/i));
            }
            f.insert(begin(f),mint::raw(0));
            for(int i=1;i<=n;i++)f[i]*=inv[i];
        };
        auto inplace_diff=[](FPS& f)->void
        {
            if(f.empty())return;
            f.erase(begin(f));
            mint cef=1;
            for(int i=0;i<(int)f.size();i++)
            {
                f[i]*=cef;
                cef++;
            }
        };
        FPS b={1,1<this->size()?(*this)[1]:0};
        FPS c={1},z1,z2={1,1};
        for(int m=2;m<deg;m<<=1)
        {
            FPS y=b;
            y.resize(2*m);
            y.ntt();
            z1=z2;
            FPS z(m);
            z=y.dot(z1);
            z.intt();
            fill(begin(z),begin(z)+m/2,mint::raw(0));
            z.ntt();
            z=z.dot(-z1);
            z.intt();
            c.insert(end(c),begin(z)+m/2,end(z));
            z2=c;
            z2.resize(2*m);
            z2.ntt();
            FPS x(begin(*this),begin(*this)+min(int(this->size()),m));
            inplace_diff(x);
            x.push_back(mint::raw(0));
            x.ntt();
            x=x.dot(y);
            x.intt();
            x-=b.diff();
            x.resize(2*m);
            for(int i=0;i<m-1;i++)x[m+i]=x[i],x[i]=0;
            x.ntt();
            x=x.dot(z2);
            x.intt();
            x.pop_back();
            inplace_integral(x);
            for(int i=m;i<min(int(this->size()),2*m);i++)x[i]+=(*this)[i];
            fill(begin(x),begin(x)+m,mint::raw(0));
            x.ntt();
            x=x.dot(y);
            x.intt();
            b.insert(end(b),begin(x)+m,end(x));
        }
        return FPS{begin(b),begin(b)+deg};
    }
    FPS log(int deg=-1)const
    {
        assert((*this)[0]==1);
        int n=this->size();
        if(deg==-1)deg=n;
        return (this->diff()*this->inv()).pre(deg-1).integral();
    }
    FPS pow(long long k, int deg=-1)const
    {
        if(deg==-1)deg=this->size();
        if(k==0)
        {
            FPS res(deg);
            res[0]=mint::raw(1);
            return res;
        }
        FPS res=*this;
        int cnt0=0;
        while(cnt0<res.size()&&res[cnt0]==0)cnt0++;
        if (cnt0>(deg-1)/k)
        {
            FPS res(deg);
            return res;
        }
        res=res>>cnt0;
        deg-=cnt0*k;
        res=((res/res[0]).log(deg)*k).exp(deg)*res[0].pow(k);
        res=res<<(cnt0*k);
        return res.pre(deg);
    }
    FPS taylor_shift(mint c)
    {
        int n=this->size();
        FPS fact(n),fact_inv(n);
        { // calc fact and fact inv
            fact[0]=1;
            for(int i=1;i<n;i++)fact[i]=i*fact[i-1];
            fact_inv[n-1]=fact[n-1].inv();
            for(int i=n-1;i>=1;i--)fact_inv[i-1]=i*fact_inv[i];
        }
        FPS res(*this);
        res=res.dot(fact);
        res=res.rev();
        FPS bs(n,mint::raw(1));
        for(int i=1;i<n;i++)bs[i]=bs[i-1]*c*fact_inv[i]*fact[i-1];
        res=(res*bs).pre(n);
        res=res.rev();
        res=res.dot(fact_inv);
        return res;
    }
    vector<mint>multipoint_evaluation(vector<mint>&x)
    {
        if(x.empty())return{};
        int m=x.size(),n=1;
        if(this->size()==0){return vector<mint>(m,0);}
        if(this->size()==1){return vector<mint>(m,(*this)[0]);}
        while(m>n)n<<=1;
        vector<FPS>f(n<<1,FPS({mint(1)}));
        for(int i=0;i<m;i++)f[i+n]=FPS({-x[i],mint(1)});
        for(int i=n-1;i>0;i--)f[i]=f[i<<1]*f[(i<<1)|1];
        f[1]=(*this)%f[1];
        for(int i=2;i<n+m;i++)f[i]=f[i>>1]%f[i];
        vector<mint>res(m);
        for(int i=0;i<m;i++)res[i]=(f[i+n].empty()?mint(0):f[i+n][0]);
        return res;
    }
private:
    FPS convolution(FPS f, FPS g)
    {
        int n=f.size(),m=g.size();
        if(n==0||m==0)return {};
        int log=1;
        while((1<<log)<n+m-1)log++;
        int sz=1<<log;
        f.resize(sz);
        g.resize(sz);
        f.ntt();
        g.ntt();
        mint inv=mint(sz).inv();
        for(int i=0;i<sz;i++)f[i]*=g[i]*inv;
        f.intt(0);
        f.resize(n+m-1);
        return f;
    }
};
template<typename mint>
struct combination
{
    combination(int n=0):inner_fac(1,1),inner_finv(1,1){init(n);}
    mint fac(int n)
    {
        init(n);
        return inner_fac[n];
    }
    mint finv(int n)
    {
        init(n);
        return inner_finv[n];
    }
    mint inv(int n)
    {
        if(n==0)return 0;
        init(n);
        return inner_fac[n-1]*inner_finv[n];
    }
    mint C(int n, int r)
    {
        if(r<0)return 0;
        if(n<0)
        {
            n=-n;
            mint res=C(n-1+r,r);
            if(r&1)res=-res;
            return res;
        }
        if(n<r)return 0;
        if(n<bound)
        {
            init(n);
            return inner_fac[n]*inner_finv[n-r]*inner_finv[r];
        }
        init(r);
        mint res=1;
        for(int i=0;i<r;i++)res*=(n-i);
        return res*inner_finv[r];
    }
    mint P(int n, int r)
    {
        if(n<0||r<0||n<r)return 0;
        if(n<bound)
        {
            init(n);
            return inner_fac[n]*inner_finv[n-r];
        }
        mint res=1;
        for(int i=0;i<r;i++)res*=(n-i);
        return res;
    }
    mint H(int n, int r)
    {
        return C(n-1+r,r);
    }
private:
    const int bound=1<<25;
    vector<mint>inner_fac,inner_finv;
    void init(int n)
    {
        int sz=inner_fac.size();
        if(sz>n)return;
        n=min(max(n,2*sz),bound);
        inner_fac.resize(n+1);
        inner_finv.resize(n+1);
        for(int i=sz;i<=n;i++)inner_fac[i]=inner_fac[i-1]*i;
        inner_finv[n]=inner_fac[n].inv();
        for(int i=n;i>sz;i--)inner_finv[i-1]=inner_finv[i]*i;
    }
};
using mint=modint<998244353>;
using FPS=Formal_Power_Series<mint>;
combination<mint>C;
vector<mint>stirling2(int n, int k)
{
    FPS f(n-k+1);
    for(int i=0;i<=n-k;i++)f[i]=C.finv(i+1);
    f=f.pow(k);
    f*=C.finv(k);
    for(int i=0;i<=n-k;i++)f[i]*=C.fac(i+k);
    vector<mint>res(n+1);
    for(int i=k;i<=n;i++)res[i]=f[i-k];
    return res;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int N,M,K;
    cin>>N>>M>>K;
    auto S=stirling2(N,K);
    mint ans=0;
    for(int n=K;n<=N;n++)ans+=C.C(N,n)*C.C(M,K)*mint(M).pow(N-n)*C.fac(K)*S[n];
    cout<<ans.val()<<endl;
}
0