結果
| 問題 | No.1321 塗るめた |
| コンテスト | |
| ユーザー | nonon |
| 提出日時 | 2024-05-01 19:48:04 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 206 ms / 2,000 ms |
| コード長 | 16,020 bytes |
| コンパイル時間 | 3,436 ms |
| コンパイル使用メモリ | 224,888 KB |
| 最終ジャッジ日時 | 2025-02-21 10:03:32 |
| ジャッジサーバーID (参考情報) | judge4 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 45 |
ソースコード
#include<bits/stdc++.h>
using namespace std;
// Integer modulo the compile-time constant mod_. The stored value is kept in [0, mod_).
template<long long mod_>
struct modint
{
    modint():value(0){}
    // Reduces any long long (negative values included) into [0, mod_).
    modint(long long v)
    {
        long long x=(long long)(v%m());
        if(x<0)x+=m();
        value=x;
    }
    // Wraps v without reduction; the caller must guarantee 0 <= v < mod_.
    static modint raw(long long v)
    {
        modint x;
        x.value=v;
        return x;
    }
    static constexpr long long mod()noexcept{return m();}
    long long val()const{return value;}
    modint& operator++()
    {
        value++;
        if(value==m())value=0;
        return *this;
    }
    modint& operator--()
    {
        if(value==0)value=m();
        value--;
        return *this;
    }
    modint operator++(int)
    {
        modint res=*this;
        ++*this;
        return res;
    }
    modint operator--(int)
    {
        modint res=*this;
        --*this;
        return res;
    }
    modint& operator+=(const modint& a)
    {
        value+=a.value;
        if(value>=m())value-=m();
        return *this;
    }
    modint& operator-=(const modint& a)
    {
        value-=a.value;
        if(value<0)value+=m();
        return *this;
    }
    modint& operator*=(const modint& a)
    {
        // Both operands lie in [0, mod_), so the unsigned 64-bit product cannot
        // overflow for mod_ <= 2^32. The remainder of an unsigned value is never
        // negative, so no sign fix-up is needed.
        unsigned long long x=value;
        x*=a.value;
        value=(long long)(x%m());
        return *this;
    }
    modint& operator/=(const modint& a)
    {
        return *this=(*this)*a.inv();
    }
    modint operator+()const{return *this;}
    modint operator-()const{return modint()-*this;}
    // Binary exponentiation. A negative exponent raises the inverse to |n|.
    // The magnitude is taken as unsigned so that n == LLONG_MIN does not overflow.
    // (Previously a negative n looped forever, because -1 >> 1 == -1.)
    modint pow(long long n)const
    {
        modint x=(n<0)?inv():*this,res=1;
        unsigned long long e=(n<0)?0ULL-(unsigned long long)n:(unsigned long long)n;
        while(e)
        {
            if(e&1)res*=x;
            x*=x;
            e>>=1;
        }
        return res;
    }
    // Extended Euclid. Returns u with value*u == gcd(value, mod_) (mod mod_).
    // This is the multiplicative inverse when gcd == 1. Note that 0.inv() yields 0.
    modint inv()const
    {
        long long a=value,b=m(),u=1,v=0;
        while(b)
        {
            long long t=a/b;
            a-=t*b;
            swap(a,b);
            u-=t*v;
            swap(u,v);
        }
        return modint(u);
    }
    friend modint operator+(const modint& a, const modint& b)
    {
        modint res=a;
        res+=b;
        return res;
    }
    friend modint operator-(const modint& a, const modint& b)
    {
        modint res=a;
        res-=b;
        return res;
    }
    friend modint operator*(const modint& a, const modint& b)
    {
        modint res=a;
        res*=b;
        return res;
    }
    friend modint operator/(const modint& a, const modint& b)
    {
        modint res=a;
        res/=b;
        return res;
    }
    friend bool operator==(const modint& a, const modint& b)
    {
        return a.value==b.value;
    }
    friend bool operator!=(const modint& a, const modint& b)
    {
        return a.value!=b.value;
    }
private:
    long long value;
    static constexpr long long m(){return mod_;}
};
// Radix-2 number-theoretic transform over `mint`. ntt/intt work in place on
// vectors whose size is a power of two.
// TODO(review): confirm the maximum supported length; it is bounded by the
// power of two dividing mod-1.
// No bit-reversal permutation is applied anywhere. ntt leaves its output in a
// permuted order and intt runs the mirrored butterflies, so pointwise products
// taken between ntt and intt are valid.
// `root` (defined out of class as 3) is assumed to be a primitive root of the
// modulus. This holds for 998244353; TODO confirm it for any other modulus used.
template<typename mint>
struct Number_Theoretic_Transform
{
    static vector<mint>dw,dw_inv;
    static int log;
    static mint root;
    // Forward transform, in place. Butterfly half-size m goes n/2, n/4, ..., 1.
    static void ntt(vector<mint>& f)
    {
        init();
        const int n=f.size();
        for(int m=n;m>>=1;)
        {
            mint w=1;
            for(int s=0,k=0;s<n;s+=(m<<1))
            {
                for(int i=s,j=s+m;i<s+m;i++,j++)
                {
                    mint x=f[i],y=f[j]*w;
                    f[i]=x+y,f[j]=x-y;
                }
                // Step the per-block twiddle. The factor is chosen by the lowest set
                // bit of the block counter.
                w*=dw[__builtin_ctz(++k)];
            }
        }
    }
    // Inverse transform, in place. Half-size m goes 1, 2, ..., n/2.
    // If flag is true the result is also scaled by 1/n. Pass false when the caller
    // has already folded that factor in.
    static void intt(vector<mint>& f, bool flag=true)
    {
        init();
        const int n=f.size();
        for(int m=1;m<n;m<<=1)
        {
            mint w=1;
            for(int s=0,k=0;s<n;s+=(m<<1))
            {
                for(int i=s,j=s+m;i<s+m;i++,j++)
                {
                    mint x=f[i],y=f[j];
                    f[i]=x+y,f[j]=(x-y)*w;
                }
                w*=dw_inv[__builtin_ctz(++k)];
            }
        }
        if(flag)
        {
            mint cef=mint(n).inv();
            for(int i=0;i<n;i++)f[i]*=cef;
        }
    }
private:
    Number_Theoretic_Transform()=default;
    // Builds the twiddle tables once, on first use.
    static void init()
    {
        if(!dw.empty())return;
        // Take the modulus from `mint` itself. It was previously hard-coded to
        // 998244353, which silently built wrong tables for any other modulus.
        const long long mod=mint::mod();
        // log = (exponent of 2 in mod-1) + 1
        long long tmp=mod-1;
        log=1;
        while(tmp%2==0)
        {
            tmp>>=1;
            log++;
        }
        dw.resize(log);
        dw_inv.resize(log);
        for(int i=0;i<log;i++)
        {
            dw[i]=-root.pow((mod-1)>>(i+2));
            dw_inv[i]=dw[i].inv();
        }
    }
};
// Out-of-class definitions of the NTT's static members.
// The twiddle tables start empty; init() fills them on first use. The root is 3.
template<typename mint>
vector<mint> Number_Theoretic_Transform<mint>::dw{};
template<typename mint>
vector<mint> Number_Theoretic_Transform<mint>::dw_inv{};
template<typename mint>
int Number_Theoretic_Transform<mint>::log{0};
template<typename mint>
mint Number_Theoretic_Transform<mint>::root(3);
// Formal power series over `mint`, stored as a coefficient vector in which
// (*this)[i] is the coefficient of x^i. Series multiplication goes through
// Number_Theoretic_Transform<mint>.
template<typename mint>
struct Formal_Power_Series:vector<mint>
{
    using FPS=Formal_Power_Series;
    using vector<mint>::vector;
    using NTT=Number_Theoretic_Transform<mint>;
    // In-place transforms. The size must be a power of two.
    void ntt(){NTT::ntt(*this);}
    void intt(bool flag=true){NTT::intt(*this,flag);}

    // Scalar +/- act on the constant term; scalar * and / act on every coefficient.
    FPS &operator+=(const mint& r)
    {
        if(this->empty())this->resize(1);
        (*this)[0]+=r;
        return *this;
    }
    FPS &operator-=(const mint& r)
    {
        if(this->empty())this->resize(1);
        (*this)[0]-=r;
        return *this;
    }
    FPS &operator*=(const mint& r)
    {
        for(mint &x:*this)x*=r;
        return *this;
    }
    FPS &operator/=(const mint& r)
    {
        mint invr=r.inv();
        for(mint &x:*this)x*=invr;
        return *this;
    }
    FPS operator+(const mint& r)const{return FPS(*this)+=r;}
    FPS operator-(const mint& r)const{return FPS(*this)-=r;}
    FPS operator*(const mint& r)const{return FPS(*this)*=r;}
    FPS operator/(const mint& r)const{return FPS(*this)/=r;}

    // Coefficient-wise sum/difference. The result is as long as the longer operand.
    FPS& operator+=(const FPS& f)
    {
        if(this->size()<f.size())this->resize(f.size());
        for(int i=0;i<(int)f.size();i++)(*this)[i]+=f[i];
        return *this;
    }
    FPS& operator-=(const FPS& f)
    {
        if(this->size()<f.size())this->resize(f.size());
        for(int i=0;i<(int)f.size();i++)(*this)[i]-=f[i];
        return *this;
    }
    // Full (untruncated) product.
    FPS& operator*=(const FPS& f)
    {
        *this=convolution(*this,f);
        return *this;
    }
    // Multiplies by f.inv(), an inverse truncated to f.size() terms. The product
    // itself is not truncated.
    FPS& operator/=(const FPS& f)
    {
        return *this*=f.inv();
    }
    // Polynomial remainder. Trailing zero coefficients are removed afterwards.
    FPS& operator%=(const FPS& f)
    {
        *this-=this->div(f)*f;
        this->shrink();
        return *this;
    }
    FPS operator+(const FPS& f)const{return FPS(*this)+=f;}
    FPS operator-(const FPS& f)const{return FPS(*this)-=f;}
    FPS operator*(const FPS& f)const{return FPS(*this)*=f;}
    FPS operator/(const FPS& f)const{return FPS(*this)/=f;}
    FPS operator%(const FPS& f)const{return FPS(*this)%=f;}
    FPS operator-()const
    {
        FPS res(this->size());
        for(int i=0;i<(int)this->size();i++)res[i]-=(*this)[i];
        return res;
    }
    // Quotient of polynomial division by f, computed from reversed series.
    // Returns {} when this is shorter than f. f.back() must be nonzero; the
    // assert in inv() checks this.
    FPS div(FPS f)const
    {
        if(this->size()<f.size())return FPS{};
        int n=this->size()-f.size()+1;
        return (rev().pre(n)*f.rev().inv(n)).pre(n).rev(n);
    }
    // First deg coefficients. deg is clamped to [0, size()].
    FPS pre(int deg)const
    {
        return FPS(begin(*this),begin(*this)+max(0,min((int)this->size(),deg)));
    }
    // Coefficients reversed. If deg != -1, the series is first resized to deg.
    FPS rev(int deg=-1)const
    {
        FPS res(*this);
        if(deg!=-1)res.resize(deg,0);
        reverse(begin(res),end(res));
        return res;
    }
    // Drops trailing zero coefficients.
    void shrink()
    {
        while(!this->empty()&&this->back()==0)this->pop_back();
    }
    // Pointwise product over the common prefix.
    FPS dot(FPS f)const
    {
        int n=min(this->size(),f.size());
        FPS res(n);
        for(int i=0;i<n;i++)res[i]=(*this)[i]*f[i];
        return res;
    }
    // Multiplies by x^deg by prepending deg zeros.
    FPS operator<<(int deg)const
    {
        FPS res(*this);
        res.insert(res.begin(),deg,0);
        return res;
    }
    FPS& operator<<=(int deg)
    {
        return *this=*this<<(deg);
    }
    // Drops the lowest deg coefficients.
    FPS operator>>(int deg)const
    {
        if((int)this->size()<=deg)return{};
        FPS res(*this);
        res.erase(res.begin(),res.begin()+deg);
        return res;
    }
    FPS& operator>>=(int deg)
    {
        return *this=*this>>(deg);
    }
    // Evaluates the polynomial at r.
    mint operator()(const mint& r)const
    {
        mint res=0,powr=1;
        for(auto x:*this)
        {
            res+=x*powr;
            powr*=r;
        }
        return res;
    }
    // Formal derivative. Its size is size()-1, or 0 for an empty series.
    FPS diff()const
    {
        int n=this->size();
        FPS res(max(0,n-1));
        for(int i=1;i<n;i++)
        {
            res[i-1]=(*this)[i]*i;
        }
        return res;
    }
    // Formal integral with constant term 0. Its size is size()+1.
    FPS integral()const
    {
        int n=this->size();
        FPS res(n+1);
        res[0]=0;
        for(int i=0;i<n;i++)
        {
            res[i+1]=(*this)[i]/(i+1);
        }
        return res;
    }
    // First deg coefficients of 1/f (default deg = size()). Requires f[0] != 0.
    // Each round doubles the number of correct coefficients, d -> 2d.
    FPS inv(int deg=-1)const
    {
        assert(!this->empty()&&((*this)[0])!=(0));
        int n=this->size();
        if(deg==-1)deg=n;
        if(deg<=0)return FPS{};
        FPS res(deg);
        res[0]=(*this)[0].inv();
        for(int d=1;d<deg;d<<=1)
        {
            FPS f(d<<1),g(d<<1);
            for(int j=0;j<min(n,2*d);j++)f[j]=(*this)[j];
            for(int j=0;j<d;j++)g[j]=res[j];
            f.ntt();
            g.ntt();
            f=f.dot(g);
            f.intt();
            for(int j=0;j<d;j++)f[j]=0;
            f.ntt();
            f=f.dot(g);
            f.intt();
            for(int j=d;j<min(2*d,deg);j++)res[j]=-f[j];
        }
        return res;
    }
    // First deg coefficients of exp(f) (default deg = size()). Requires f[0] == 0;
    // an empty series is treated as 0. b holds the partial result and grows
    // from 2 to 2m coefficients in each round.
    FPS exp(int deg=-1)const
    {
        assert(this->empty()||(*this)[0]==0);
        if(deg==-1)deg=this->size();
        vector<mint>inv;
        inv.reserve(deg+1);
        inv.push_back(mint::raw(0));
        inv.push_back(mint::raw(1));
        // Integrates f in place, caching the modular inverses 1/i in `inv`.
        auto inplace_integral=[&](FPS& f)->void
        {
            int n=f.size();
            long long mod=mint::mod();
            while(inv.size()<=f.size())
            {
                int i=inv.size();
                inv.push_back((-inv[mod%i])*(mod/i));
            }
            f.insert(begin(f),mint::raw(0));
            for(int i=1;i<=n;i++)f[i]*=inv[i];
        };
        // Differentiates f in place.
        auto inplace_diff=[](FPS& f)->void
        {
            if(f.empty())return;
            f.erase(begin(f));
            mint cef=1;
            for(int i=0;i<(int)f.size();i++)
            {
                f[i]*=cef;
                cef++;
            }
        };
        FPS b={1,1<this->size()?(*this)[1]:0};
        FPS c={1},z1,z2={1,1};
        for(int m=2;m<deg;m<<=1)
        {
            FPS y=b;
            y.resize(2*m);
            y.ntt();
            z1=z2;
            FPS z(m);
            z=y.dot(z1);
            z.intt();
            fill(begin(z),begin(z)+m/2,mint::raw(0));
            z.ntt();
            z=z.dot(-z1);
            z.intt();
            c.insert(end(c),begin(z)+m/2,end(z));
            z2=c;
            z2.resize(2*m);
            z2.ntt();
            FPS x(begin(*this),begin(*this)+min(int(this->size()),m));
            inplace_diff(x);
            // Pad the derivative to exactly m terms so the NTT length stays a power
            // of two even when this series is shorter than m. When size() >= m this
            // adds the same single zero that push_back did.
            x.resize(m);
            x.ntt();
            x=x.dot(y);
            x.intt();
            x-=b.diff();
            x.resize(2*m);
            for(int i=0;i<m-1;i++)x[m+i]=x[i],x[i]=0;
            x.ntt();
            x=x.dot(z2);
            x.intt();
            x.pop_back();
            inplace_integral(x);
            for(int i=m;i<min(int(this->size()),2*m);i++)x[i]+=(*this)[i];
            fill(begin(x),begin(x)+m,mint::raw(0));
            x.ntt();
            x=x.dot(y);
            x.intt();
            b.insert(end(b),begin(x)+m,end(x));
        }
        return FPS{begin(b),begin(b)+deg};
    }
    // First deg coefficients of log(f) = integral(f'/f) (default deg = size()).
    // Requires f[0] == 1. The inverse is now taken to deg terms; with only size()
    // terms, as before, the result was wrong whenever deg > size().
    FPS log(int deg=-1)const
    {
        assert(!this->empty()&&(*this)[0]==1);
        int n=this->size();
        if(deg==-1)deg=n;
        if(deg<=0)return FPS{};
        FPS res=(this->diff()*this->inv(deg)).pre(deg-1).integral();
        res.resize(deg);
        return res;
    }
    // First deg coefficients of f^k (default deg = size()).
    // Leading zeros are factored out: f = x^cnt0 * c * g with g[0] == 1, and then
    // f^k = x^(cnt0*k) * c^k * exp(k*log(g)).
    FPS pow(long long k, int deg=-1)const
    {
        const int n=this->size();
        if(deg==-1)deg=n;
        if(deg<=0)return FPS{};
        if(k==0)
        {
            FPS one(deg);
            one[0]=mint::raw(1);
            return one;
        }
        int cnt0=0;
        while(cnt0<n&&(*this)[cnt0]==0)cnt0++;
        // Either the series is all zeros, or its lowest term x^(cnt0*k) lies at or
        // beyond x^deg. The second test also keeps cnt0*k from overflowing below.
        if(cnt0==n||cnt0>(deg-1)/k)return FPS(deg);
        const int shift=(int)(cnt0*k);
        const int rem=deg-shift;
        FPS g=*this>>cnt0;
        const mint c=g[0];
        FPS res=((g/c).log(rem)*k).exp(rem)*c.pow(k);
        res=res<<shift;
        // Keep all deg terms. Truncating to rem, as before, dropped the top
        // cnt0*k coefficients after the shift.
        res.resize(deg);
        return res;
    }
    // Returns the coefficients of f(x + c). The size is unchanged.
    FPS taylor_shift(mint c)const
    {
        int n=this->size();
        if(n==0)return FPS{};
        FPS fact(n),fact_inv(n);
        { // calc fact and fact inv
            fact[0]=1;
            for(int i=1;i<n;i++)fact[i]=i*fact[i-1];
            fact_inv[n-1]=fact[n-1].inv();
            for(int i=n-1;i>=1;i--)fact_inv[i-1]=i*fact_inv[i];
        }
        FPS res(*this);
        res=res.dot(fact);
        res=res.rev();
        // bs[i] = c^i / i!
        FPS bs(n,mint::raw(1));
        for(int i=1;i<n;i++)bs[i]=bs[i-1]*c*fact_inv[i]*fact[i-1];
        res=(res*bs).pre(n);
        res=res.rev();
        res=res.dot(fact_inv);
        return res;
    }
    // Evaluates the polynomial at every point of x. Builds a product tree of
    // (X - x[i]) and reduces the polynomial modulo each node on the way down.
    vector<mint>multipoint_evaluation(const vector<mint>&x)const
    {
        if(x.empty())return{};
        int m=x.size(),n=1;
        if(this->size()==0){return vector<mint>(m,0);}
        if(this->size()==1){return vector<mint>(m,(*this)[0]);}
        while(m>n)n<<=1;
        vector<FPS>f(n<<1,FPS({mint(1)}));
        for(int i=0;i<m;i++)f[i+n]=FPS({-x[i],mint(1)});
        for(int i=n-1;i>0;i--)f[i]=f[i<<1]*f[(i<<1)|1];
        f[1]=(*this)%f[1];
        for(int i=2;i<n+m;i++)f[i]=f[i>>1]%f[i];
        vector<mint>res(m);
        for(int i=0;i<m;i++)res[i]=(f[i+n].empty()?mint(0):f[i+n][0]);
        return res;
    }
private:
    // Full product of f and g via NTT. The 1/sz scaling is folded into the
    // pointwise product, so intt is called with flag = false.
    FPS convolution(FPS f, FPS g)
    {
        int n=f.size(),m=g.size();
        if(n==0||m==0)return {};
        int log=1;
        while((1<<log)<n+m-1)log++;
        int sz=1<<log;
        f.resize(sz);
        g.resize(sz);
        f.ntt();
        g.ntt();
        mint inv=mint(sz).inv();
        for(int i=0;i<sz;i++)f[i]*=g[i]*inv;
        f.intt(0);
        f.resize(n+m-1);
        return f;
    }
};
// Factorial-based combinatorics over `mint`. The tables of n! and 1/n! are
// grown on demand: to at least double their current size, but never beyond `bound`.
template<typename mint>
struct combination
{
    // Starts with 0! = 1 and 1/0! = 1, then pre-grows the tables to cover n.
    combination(int n=0):inner_fac(1,1),inner_finv(1,1){init(n);}
    // n!
    mint fac(int n){init(n);return inner_fac[n];}
    // 1/n!
    mint finv(int n){init(n);return inner_finv[n];}
    // 1/n, computed as (n-1)!/n!. Returns 0 for n == 0.
    mint inv(int n)
    {
        if(n==0)return 0;
        init(n);
        return inner_fac[n-1]*inner_finv[n];
    }
    // Binomial coefficient. A negative upper index uses C(-q, r) = (-1)^r C(q-1+r, r).
    // When n >= bound, the falling factorial n(n-1)...(n-r+1) is multiplied out directly.
    mint C(int n, int r)
    {
        if(r<0)return 0;
        if(n<0)
        {
            const mint v=C(-n-1+r,r);
            if(r&1)return -v;
            return v;
        }
        if(n<r)return 0;
        if(n>=bound)
        {
            init(r);
            mint num=1;
            for(int t=n;t>n-r;t--)num*=t;
            return num*inner_finv[r];
        }
        init(n);
        return inner_fac[n]*inner_finv[n-r]*inner_finv[r];
    }
    // Falling factorial n!/(n-r)!. Returns 0 for invalid arguments.
    mint P(int n, int r)
    {
        if(n<0||r<0||n<r)return 0;
        if(n>=bound)
        {
            mint prod=1;
            for(int t=n;t>n-r;t--)prod*=t;
            return prod;
        }
        init(n);
        return inner_fac[n]*inner_finv[n-r];
    }
    // Multiset coefficient C(n-1+r, r).
    mint H(int n, int r){return C(n-1+r,r);}
private:
    const int bound=1<<25;
    vector<mint>inner_fac,inner_finv;
    // Makes index n available. The tables grow to max(n, 2*current size), capped at bound.
    void init(int n)
    {
        const int have=inner_fac.size();
        if(n<have)return;
        const int top=min(max(n,2*have),bound);
        inner_fac.resize(top+1);
        inner_finv.resize(top+1);
        for(int i=have;i<=top;i++)inner_fac[i]=inner_fac[i-1]*i;
        inner_finv[top]=inner_fac[top].inv();
        for(int i=top;i>have;i--)inner_finv[i-1]=inner_finv[i]*i;
    }
};
// Arithmetic modulo the prime 998244353.
using mint = modint<998244353>;
// Formal power series with mint coefficients.
using FPS = Formal_Power_Series<mint>;
// Shared factorial/binomial table, grown on demand.
combination<mint> C{};
// Returns res of size n+1 with res[i] = S(i, k), the Stirling numbers of the
// second kind for a fixed k. Uses the identity
//   sum_i S(i,k) x^i / i! = (e^x - 1)^k / k!
// with (e^x - 1) = x * sum_{j>=0} x^j / (j+1)!, truncated to the n-k+1 terms needed.
// For k < 0 or k > n every S(i,k), i <= n, is zero. Those inputs used to build a
// negative-size series and crash; they now return all zeros.
vector<mint>stirling2(int n, int k)
{
    if(n<0)return {};
    vector<mint>res(n+1);
    if(k<0||k>n)return res;
    FPS f(n-k+1);
    for(int i=0;i<=n-k;i++)f[i]=C.finv(i+1);
    f=f.pow(k);
    f*=C.finv(k);
    // The coefficient of x^(i-k) in f is S(i,k)/i!.
    for(int i=k;i<=n;i++)res[i]=f[i-k]*C.fac(i);
    return res;
}
// Reads N, M, K and prints
//   sum_{n=K}^{N} C(N,n) * C(M,K) * M^(N-n) * K! * S(n,K)   (mod 998244353),
// where S(n,K) are Stirling numbers of the second kind.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int N,M,K;
    cin>>N>>M>>K;
    auto S=stirling2(N,K);
    mint ans=0;
    for(int n=K;n<=N;n++)ans+=C.C(N,n)*mint(M).pow(N-n)*S[n];
    // C(M,K) * K! does not depend on n, so it is factored out of the sum.
    // When K > N the sum is empty, and fac(K) is not evaluated.
    if(K<=N)ans*=C.C(M,K)*C.fac(K);
    cout<<ans.val()<<'\n';
}
nonon