結果
問題 | No.1100 Boxes |
ユーザー |
|
提出日時 | 2020-05-18 16:30:29 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 |
TLE
(最新)
AC
(最初)
|
実行時間 | - |
コード長 | 9,421 bytes |
コンパイル時間 | 2,348 ms |
コンパイル使用メモリ | 124,556 KB |
最終ジャッジ日時 | 2025-01-10 12:53:34 |
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 22 RE * 1 TLE * 13 |
コンパイルメッセージ
main.cpp: In function ‘void solve()’: main.cpp:380:8: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result] 380 | scanf("%d%d",&N,&K); | ~~~~~^~~~~~~~~~~~~~
ソースコード
#include <string>#include <iostream>#include <vector>#include <cassert>#include <random>#include <algorithm>#include <deque>#include <cstring>#include <time.h>#include <cstdio>template<typename T=int>inline T get(){char c=getchar(); bool neg=(c=='-');T res=neg?0:c-'0'; while(isdigit(c=getchar()))res=res*10+(c-'0');return neg?-res:res;}template<typename T=int>inline void put(T x,char c='\n'){if(x==0)putchar('0');else{if(x<0) std::putchar('-'),x*=-1;int d[20],i=0;while(x)d[i++]=x%10,x/=10;while(i--) std::putchar('0'+d[i]);} putchar(c);}constexpr long long p=998244353;inline int ADD(int a,int b) {return a+b>=p?a+b-p:a+b;}inline int SUB(int a,int b) {return a-b<0?a-b+p:a-b;}int deg(std::vector<int> a){int ret=a.size()-1;while (ret>=0 && a[ret]==0) --ret;return ret;}std::vector<int> trim(std::vector<int> a,int n) {int asize=a.size();a.resize(n);for (int i=asize;i<n;++i) a[i]=0;return a;}std::vector<int> norm(std::vector<int> a) {while (a.size()>1 && a.back()==0) a.pop_back();return a;}//f->fx^(shift)std::vector<int> shift(std::vector<int> &a,int shift) {std::vector<int> b(std::max(0,(int)a.size()+shift),0);for (int i=0;i<(int)b.size();++i) b[i]=(0<=i-shift&&i-shift<(int)a.size())?a[i-shift]:0;return b;}inline long long pow_mod(long long a,long long n) {long long ret=1;for (;n>0;n>>=1,a=a*a%p) if(n%2==1) ret=ret*a%p;return ret;}inline int inv(int a) {a%=p;if (a<0) a+=p;int u=p;int v=a;int s=0;int t=1;// sa=u// ta=vwhile (v!=0) {int q=u/v;s-=q*t;u-=q*v;std::swap(s,t);std::swap(u,v);}assert((long long)a*(s+p)%p==1);return s>=0?s:s+p;}std::vector<int> monic(std::vector<int> a) {if (deg(a)==-1) return a;a.resize(deg(a)+1);long long coe=inv(a[a.size()-1]);for (int i=0;i<(int)a.size();++i) {a[i]=(int)(coe*a[i]%p);}return a;}std::vector<int> add(std::vector<int> a,std::vector<int> b) {int n=std::max(a.size(),b.size());a.resize(n);b.resize(n);for (int i=0;i<n;++i) a[i]=ADD(a[i],b[i]);return a;}std::vector<int> subtract(std::vector<int> a,std::vector<int> b) {int n=std::max(a.size(),b.size());a.resize(n);b.resize(n);for (int i=0;i<n;++i) a[i]=ADD(a[i],p-b[i]);return a;}std::vector<int> mul_naive(std::vector<int> a,std::vector<int> b) {std::vector<int> ret(a.size()+b.size()-1,0);for (int i=0;i<(int)a.size();++i) {for (int j=0;j<(int)b.size();++j) {ret[i+j]=(int)((ret[i+j]+1LL*a[i]*b[j])%p);}}ret=norm(ret);return ret;}void fft_(int n,int g,int stride,std::vector<int> &from,std::vector<int> &to,bool flag){if (n==1) {if (flag) for (int i=0;i<stride;++i) to[i]=from[i];return;} else {int w=pow_mod(g,(p-1)/n);int mul=1;for (int i=0;i<n/2;++i) {for (int src=0;src<stride;++src) {const int A=from[src+stride*(i+ 0)];const int B=from[src+stride*(i+n/2)];to[src+stride*(2*i+0)]=ADD(A,B);to[src+stride*(2*i+1)]=1LL*ADD(A,p-B)*mul%p;}mul=1LL*mul*w%p;}fft_(n/2,g,2*stride,to,from,!flag);}}void fft4_(int n,int g,int j,int stride,std::vector<int> &from,std::vector<int> &to,bool flag){int w=pow_mod(g,(p-1)/n);long long w1,w2,w3;int i,src,n0,n1,n2,n3,A,B,C,D,apc,amc,bpd,jbmd;while (n>2) {n0=0;n1=n/4;n2=n1+n1;n3=n1+n2;w1=1;for (i=0;i<n1;++i) {w2=w1*w1%p;w3=w1*w2%p;for (src=0;src<stride;++src) {A=from[src+stride*(i+n0)];B=from[src+stride*(i+n1)];C=from[src+stride*(i+n2)];D=from[src+stride*(i+n3)];apc=ADD(A,C);amc=SUB(A,C);bpd=ADD(B,D);jbmd=1LL*j*SUB(B,D)%p;to[src+stride*(4*i+0)]=ADD(apc,bpd);to[src+stride*(4*i+1)]=w1*(amc+p-jbmd)%p;to[src+stride*(4*i+2)]=w2*(A+C+p-bpd)%p;to[src+stride*(4*i+3)]=w3*(A+p-C+jbmd)%p;}w1=1LL*w1*w%p;}n/=4;stride*=4;flag=!flag;w=1LL*w*w%p;w=1LL*w*w%p;std::swap(to,from);}if (n<=2) fft_(n,g,stride,from,to,flag);if (from.size()>to.size()) std::swap(from,to);}std::vector<int> tmp_fft(1<<21);void fft(std::vector<int> &a,int g) {fft4_(a.size(),g,pow_mod(g,(p-1)/4*3),1,a,tmp_fft,false);}// (sx^p+u)(tx^p+v)// =stx^(2p)+(sv+ut)x^p+uv// =stx^(2p)+((s+u)(t+v)-(st-uv))x^p+uvvoid mul_karatsuba(int a[],int b[],int c[],int res[],int n) {if (n<=8) {for (int i=0;i<2*n;++i) res[i]=0;for (int i=0;i<n;++i) for(int j=0;j<n;++j) res[i+j]=ADD(res[i+j],(int)(1LL*a[i]*b[j]%p));return;}int *x0=res;int *x1=res+n;int *x2=res+n*2;int *a0=a;int *a1=a+n/2;int *b0=b;int *b1=b+n/2;int *c0=c;int *c1=c+n/2;mul_karatsuba(a0,b0,c+n*2,x0,n/2);mul_karatsuba(a1,b1,c+n*2,x1,n/2);for (int i=0;i<n/2;++i) {c0[i]=ADD(a0[i],a1[i]);c1[i]=ADD(b0[i],b1[i]);}mul_karatsuba(c0,c1,c+n*2,x2,n/2);for (int i=0;i<n;++i) {x2[i]=ADD(ADD(x2[i],p-x0[i]),p-x1[i]);}for (int i=0;i<n;++i) {res[i+n/2]=ADD(res[i+n/2],x2[i]);}}std::vector<int> mul_fft(std::vector<int> a,std::vector<int> b) {int g=3;int n=1;int need=a.size()+b.size()-1;while (n<need) n*=2;a.resize(n);b.resize(n);fft(a,g);fft(b,g);int inv_n=inv(n);for (int i=0;i<n;++i) a[i]=(int)(1LL*a[i]*b[i]%p*inv_n%p);fft(a,inv(g));a.resize(need);return a;}std::vector<int> karatsuba(std::vector<int> a,std::vector<int> b) {int need=std::max(a.size(),b.size());int n=1;while (n<need) n*=2;std::vector<int> a_=trim(a,n);std::vector<int> b_=trim(b,n);std::vector<int> c(4*n);std::vector<int> res(4*n);mul_karatsuba(a_.data(),b_.data(),c.data(),res.data(),n);res.resize(a.size()+b.size()-1);return res;}std::vector<int> mul(std::vector<int> a,std::vector<int> b) {if (std::max(a.size(),b.size())<=64) {return karatsuba(a,b);} else {std::vector<int> ret=mul_fft(a,b);ret=norm(ret);return ret;}}std::vector<int> mul(std::vector<int> a,int b) {int n=a.size();std::vector<int> c(n);for (int i=0;i<n;++i) c[i]=(int)(1LL*a[i]*(p+b)%p);return c;}std::vector<std::vector<std::vector<int>>> mul(std::vector<std::vector<std::vector<int>>> a,std::vector<std::vector<std::vector<int>>> b) {std::vector<std::vector<std::vector<int>>> ret(a.size(),std::vector<std::vector<int>>(b[0].size(),std::vector<int>()));for (int i=0;i<(int)a.size();++i) {for (int j=0;j<(int)b[i].size();++j) {for (int k=0;k<(int)a[i].size();++k) {ret[i][j]=add(ret[i][j],mul(a[i][k],b[k][j]));}}}return ret;}// f <- -f(fg-1)+fstd::vector<int> inv(std::vector<int> g) {int n=g.size();std::vector<int> f={inv(g[0])};long long root=3;long long iroot=inv(3);for (int len=1;len<n;len*=2) {std::vector<int> f_fft=trim(f,2*len);std::vector<int> g_fft=trim(g,2*len);fft(f_fft,root);fft(g_fft,root);long long isize=inv(2*len);for (int i=0;i<2*len;++i) g_fft[i]=(int)(1LL*g_fft[i]*f_fft[i]%p*isize%p);fft(g_fft,iroot);for (int i=0;i<len;++i) g_fft[i]=0;fft(g_fft,root);for (int i=0;i<2*len;++i) g_fft[i]=(int)(1LL*g_fft[i]*f_fft[i]%p*isize%p);fft(g_fft,iroot);for (int i=0;i<len;++i) g_fft[i]=0;f.resize(std::min(n,2*len));for (int i=0;i<2*len;++i) {f[i]=ADD(f[i],p-g_fft[i]);}}return f;}std::vector<int> divide(std::vector<int> a,std::vector<int> b) {a=norm(a);b=norm(b);if (a.size()<b.size()) {std::vector<int> ret(1,0);return ret;}std::reverse(a.begin(),a.end());std::reverse(b.begin(),b.end());int n=a.size()-b.size()+1;a.resize(n);b.resize(n);a=mul(a,inv(b));a.resize(n);std::reverse(a.begin(),a.end());return a;}std::vector<int> mod(std::vector<int> a,std::vector<int> b) {return subtract(a,mul(b,divide(a,b)));}std::vector<int> differentiate(std::vector<int> a) {int n=a.size();for (int i=0;i+1<n;++i) a[i]=1LL*(i+1)*a[i+1]%p;a[n-1]=0;return a;}std::vector<int> integrate(std::vector<int> a) {for (int i=a.size()-1;i>=1;--i) a[i]=1LL*inv(i)*a[i-1]%p;a[0]=0;return a;}std::vector<int> log(std::vector<int> a) {assert(a[0]==1);return integrate(mul(differentiate(a),inv(a)));}std::vector<int> exp(std::vector<int> a) {assert(a[0]==0);int n=a.size();std::vector<int> b={1};for (int len=1;len<n;len*=2) {std::vector<int> tmp=subtract(trim(a,2*len),log(trim(b,2*len)));++tmp[0];b=trim(mul(b,tmp),2*len);}return b;}std::vector<int> pow(std::vector<int> a,int n) {int s=0;while (s<(int)a.size() && a[s]==0) ++s;if (s==(int)a.size()) return a;a=shift(a,-s);int b=inv(a[0]);for (int i=0;i<(int)a.size();++i) a[i]=1LL*b*a[i]%p;a=log(a);for (int i=0;i<(int)a.size();++i) a[i]=1LL*n%p*a[i]%p;a=exp(a);b=pow_mod(inv(b),n%(p-1));for (int i=0;i<(int)a.size();++i) a[i]=1LL*b*a[i]%p;a=shift(a,s*n);return a;}void solve() {int N, K;scanf("%d%d",&N,&K);std::vector<int> f(N+1);int fac=1;for (int i=0;i<=N;++i) {f[i]=inv(fac);if (i<N) fac=1LL*fac*(i+1)%p;}f[0]=(1LL*f[0]+p-2)%p;f=pow(f,K);long long ans=(1LL*pow_mod(K,N)+p-1LL*f[N]*fac%p)%p*inv(2)%p;printf("%lld\n",ans);}int main() {solve();}