結果

問題 No.1100 Boxes
ユーザー 37zigen37zigen
提出日時 2020-05-18 16:30:29
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
TLE  
(最新)
AC  
(最初)
実行時間 -
コード長 9,421 bytes
コンパイル時間 1,896 ms
コンパイル使用メモリ 125,128 KB
実行使用メモリ 83,956 KB
最終ジャッジ日時 2024-10-01 22:10:46
合計ジャッジ時間 10,617 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 5 ms
18,460 KB
testcase_01 AC 4 ms
11,640 KB
testcase_02 AC 5 ms
11,600 KB
testcase_03 AC 8 ms
11,684 KB
testcase_04 AC 5 ms
11,508 KB
testcase_05 AC 5 ms
11,568 KB
testcase_06 AC 5 ms
11,524 KB
testcase_07 AC 5 ms
11,520 KB
testcase_08 AC 5 ms
11,640 KB
testcase_09 AC 5 ms
11,648 KB
testcase_10 AC 5 ms
11,680 KB
testcase_11 AC 5 ms
11,512 KB
testcase_12 AC 5 ms
11,556 KB
testcase_13 AC 5 ms
11,688 KB
testcase_14 AC 5 ms
11,664 KB
testcase_15 AC 5 ms
11,564 KB
testcase_16 AC 5 ms
11,576 KB
testcase_17 AC 6 ms
11,596 KB
testcase_18 AC 8 ms
11,808 KB
testcase_19 AC 32 ms
12,500 KB
testcase_20 AC 63 ms
13,552 KB
testcase_21 AC 64 ms
13,316 KB
testcase_22 AC 534 ms
25,932 KB
testcase_23 AC 545 ms
27,168 KB
testcase_24 AC 1,160 ms
43,640 KB
testcase_25 TLE -
testcase_26 TLE -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
testcase_30 -- -
testcase_31 -- -
testcase_32 -- -
testcase_33 -- -
testcase_34 -- -
testcase_35 -- -
testcase_36 -- -
testcase_37 -- -
testcase_38 -- -
testcase_39 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <string>
#include <iostream>
#include <vector>
#include <cassert>
#include <random>
#include <algorithm>
#include <deque>
#include <cstring>
#include <time.h>
#include <cstdio>

template<typename T=int>inline T get(){
  char c=getchar(); bool neg=(c=='-');
  T res=neg?0:c-'0'; while(isdigit(c=getchar()))res=res*10+(c-'0');
  return neg?-res:res;
}

template<typename T=int>inline void put(T x,char c='\n'){
  if(x==0)putchar('0');
  else{
    if(x<0) std::putchar('-'),x*=-1;
    int d[20],i=0;
    while(x)d[i++]=x%10,x/=10;
    while(i--) std::putchar('0'+d[i]);
  } putchar(c);
}

constexpr long long p=998244353;

inline int ADD(int a,int b) {
  return a+b>=p?a+b-p:a+b;
}

inline int SUB(int a,int b) {
  return a-b<0?a-b+p:a-b;
}

int deg(std::vector<int> a){
  int ret=a.size()-1;
  while (ret>=0 && a[ret]==0) --ret;
  return ret;
}

std::vector<int> trim(std::vector<int> a,int n) {
  int asize=a.size();
  a.resize(n);
  for (int i=asize;i<n;++i) a[i]=0;
  return a;
}

std::vector<int> norm(std::vector<int> a) {
  while (a.size()>1 && a.back()==0) a.pop_back();
  return a;
}

//f->fx^(shift)
std::vector<int> shift(std::vector<int> &a,int shift) {
  std::vector<int> b(std::max(0,(int)a.size()+shift),0);
  for (int i=0;i<(int)b.size();++i) b[i]=(0<=i-shift&&i-shift<(int)a.size())?a[i-shift]:0;
  return b;
}

inline long long pow_mod(long long a,long long n) {
  long long ret=1;
  for (;n>0;n>>=1,a=a*a%p) if(n%2==1) ret=ret*a%p;
  return ret;
}

inline int inv(int a) {
  a%=p;
  if (a<0) a+=p;
  int u=p;
  int v=a;
  int s=0;
  int t=1;
  // sa=u
  // ta=v
  while (v!=0) {
    int q=u/v;
    s-=q*t;u-=q*v;
    std::swap(s,t);
    std::swap(u,v);
  }
  assert((long long)a*(s+p)%p==1);
  return s>=0?s:s+p;
}

std::vector<int> monic(std::vector<int> a) {
  if (deg(a)==-1) return a;
  a.resize(deg(a)+1);
  long long coe=inv(a[a.size()-1]);
  for (int i=0;i<(int)a.size();++i) {
    a[i]=(int)(coe*a[i]%p);
  }
  return a;
}

std::vector<int> add(std::vector<int> a,std::vector<int> b) {
  int n=std::max(a.size(),b.size());
  a.resize(n);b.resize(n);
  for (int i=0;i<n;++i) a[i]=ADD(a[i],b[i]);
  return a;
}

std::vector<int> subtract(std::vector<int> a,std::vector<int> b) {
  int n=std::max(a.size(),b.size());
  a.resize(n);b.resize(n);
  for (int i=0;i<n;++i) a[i]=ADD(a[i],p-b[i]);
  return a;
}

std::vector<int> mul_naive(std::vector<int> a,std::vector<int> b) {
  std::vector<int> ret(a.size()+b.size()-1,0);
  for (int i=0;i<(int)a.size();++i) {
    for (int j=0;j<(int)b.size();++j) {
      ret[i+j]=(int)((ret[i+j]+1LL*a[i]*b[j])%p);
    }
  }
  ret=norm(ret);
  return ret;
}

void fft_(int n,int g,int stride,std::vector<int> &from,std::vector<int> &to,bool flag){
  if (n==1) {
    if (flag) for (int i=0;i<stride;++i) to[i]=from[i];
    return;
  } else {
    int w=pow_mod(g,(p-1)/n);
    int mul=1;
    for (int i=0;i<n/2;++i) {
      for (int src=0;src<stride;++src) {
        const int A=from[src+stride*(i+  0)];
        const int B=from[src+stride*(i+n/2)];
        to[src+stride*(2*i+0)]=ADD(A,B);
        to[src+stride*(2*i+1)]=1LL*ADD(A,p-B)*mul%p;
      }
      mul=1LL*mul*w%p;
    }
    fft_(n/2,g,2*stride,to,from,!flag);
  }
}

void fft4_(int n,int g,int j,int stride,std::vector<int> &from,std::vector<int> &to,bool flag){
  int w=pow_mod(g,(p-1)/n);
  long long w1,w2,w3;
  int i,src,n0,n1,n2,n3,A,B,C,D,apc,amc,bpd,jbmd;
  
  while (n>2) {
    n0=0;
    n1=n/4;
    n2=n1+n1;
    n3=n1+n2;
    w1=1;
    for (i=0;i<n1;++i) {
      w2=w1*w1%p;
      w3=w1*w2%p;
      for (src=0;src<stride;++src) {
        A=from[src+stride*(i+n0)];
        B=from[src+stride*(i+n1)];
        C=from[src+stride*(i+n2)];
        D=from[src+stride*(i+n3)];
        apc=ADD(A,C);
        amc=SUB(A,C);
        bpd=ADD(B,D);
        jbmd=1LL*j*SUB(B,D)%p;
        to[src+stride*(4*i+0)]=ADD(apc,bpd);
        to[src+stride*(4*i+1)]=w1*(amc+p-jbmd)%p;
        to[src+stride*(4*i+2)]=w2*(A+C+p-bpd)%p;
        to[src+stride*(4*i+3)]=w3*(A+p-C+jbmd)%p;
      }
      w1=1LL*w1*w%p;
    }
    n/=4;
    stride*=4;
    flag=!flag;
    w=1LL*w*w%p;
    w=1LL*w*w%p;
    std::swap(to,from);
  }
  if (n<=2) fft_(n,g,stride,from,to,flag);
  if (from.size()>to.size()) std::swap(from,to);
}



std::vector<int> tmp_fft(1<<21);
void fft(std::vector<int> &a,int g) {
  fft4_(a.size(),g,pow_mod(g,(p-1)/4*3),1,a,tmp_fft,false);
}

//  (sx^p+u)(tx^p+v)
// =stx^(2p)+(sv+ut)x^p+uv
// =stx^(2p)+((s+u)(t+v)-(st-uv))x^p+uv
void mul_karatsuba(int a[],int b[],int c[],int res[],int n) {
  if (n<=8) {
    for (int i=0;i<2*n;++i) res[i]=0;
    for (int i=0;i<n;++i) for(int j=0;j<n;++j) res[i+j]=ADD(res[i+j],(int)(1LL*a[i]*b[j]%p));
    return;
  }
  int *x0=res;
  int *x1=res+n;
  int *x2=res+n*2;
  int *a0=a;
  int *a1=a+n/2;
  int *b0=b;
  int *b1=b+n/2;
  int *c0=c;
  int *c1=c+n/2;
  mul_karatsuba(a0,b0,c+n*2,x0,n/2);
  mul_karatsuba(a1,b1,c+n*2,x1,n/2);
  for (int i=0;i<n/2;++i) {
    c0[i]=ADD(a0[i],a1[i]);
    c1[i]=ADD(b0[i],b1[i]);
  }
  mul_karatsuba(c0,c1,c+n*2,x2,n/2);
  for (int i=0;i<n;++i) {
    x2[i]=ADD(ADD(x2[i],p-x0[i]),p-x1[i]);
  }
  for (int i=0;i<n;++i) {
    res[i+n/2]=ADD(res[i+n/2],x2[i]);
  }
}

std::vector<int> mul_fft(std::vector<int> a,std::vector<int> b) {
  int g=3;
  int n=1;
  int need=a.size()+b.size()-1;
  while (n<need) n*=2;
  a.resize(n);
  b.resize(n);
  fft(a,g);
  fft(b,g);
  int inv_n=inv(n);
  for (int i=0;i<n;++i) a[i]=(int)(1LL*a[i]*b[i]%p*inv_n%p);
  fft(a,inv(g));
  a.resize(need);
  return a;
}

std::vector<int> karatsuba(std::vector<int> a,std::vector<int> b) {
  int need=std::max(a.size(),b.size());
  int n=1;
  while (n<need) n*=2;
  std::vector<int> a_=trim(a,n);
  std::vector<int> b_=trim(b,n);
  std::vector<int> c(4*n);
  std::vector<int> res(4*n);
  mul_karatsuba(a_.data(),b_.data(),c.data(),res.data(),n);
  res.resize(a.size()+b.size()-1);
  return res;
}

std::vector<int> mul(std::vector<int> a,std::vector<int> b) {
  if (std::max(a.size(),b.size())<=64) {
    return karatsuba(a,b);
  } else {
    std::vector<int> ret=mul_fft(a,b);
    ret=norm(ret);
    return ret;
  } 
}

std::vector<int> mul(std::vector<int> a,int b) {
  int n=a.size();
  std::vector<int> c(n);
  for (int i=0;i<n;++i) c[i]=(int)(1LL*a[i]*(p+b)%p);
  return c;
}

std::vector<std::vector<std::vector<int>>> mul(std::vector<std::vector<std::vector<int>>> a,std::vector<std::vector<std::vector<int>>> b) {
  std::vector<std::vector<std::vector<int>>> ret(a.size(),std::vector<std::vector<int>>(b[0].size(),std::vector<int>()));
  for (int i=0;i<(int)a.size();++i) {
    for (int j=0;j<(int)b[i].size();++j) {
      for (int k=0;k<(int)a[i].size();++k) {
        ret[i][j]=add(ret[i][j],mul(a[i][k],b[k][j]));
      }
    }
  }
  return ret;
}

// f  <- -f(fg-1)+f
std::vector<int> inv(std::vector<int> g) {
  int n=g.size();
  std::vector<int> f={inv(g[0])};
  long long root=3;
  long long iroot=inv(3);
  for (int len=1;len<n;len*=2) {
    std::vector<int> f_fft=trim(f,2*len);
    std::vector<int> g_fft=trim(g,2*len);
    fft(f_fft,root);
    fft(g_fft,root);
    long long isize=inv(2*len);
    for (int i=0;i<2*len;++i) g_fft[i]=(int)(1LL*g_fft[i]*f_fft[i]%p*isize%p);
    fft(g_fft,iroot);
    for (int i=0;i<len;++i) g_fft[i]=0;
    
    fft(g_fft,root);
    for (int i=0;i<2*len;++i) g_fft[i]=(int)(1LL*g_fft[i]*f_fft[i]%p*isize%p);
    fft(g_fft,iroot);
    for (int i=0;i<len;++i) g_fft[i]=0;
    f.resize(std::min(n,2*len));
    for (int i=0;i<2*len;++i) {
      f[i]=ADD(f[i],p-g_fft[i]);
    }
  }
  return f;
}

std::vector<int> divide(std::vector<int> a,std::vector<int> b) {
  a=norm(a);
  b=norm(b);
  if (a.size()<b.size()) {
    std::vector<int> ret(1,0);
    return ret;
  }
  std::reverse(a.begin(),a.end());
  std::reverse(b.begin(),b.end());
  int n=a.size()-b.size()+1;
  a.resize(n);
  b.resize(n);
  a=mul(a,inv(b));
  a.resize(n);
  std::reverse(a.begin(),a.end());
  return a;
}

std::vector<int> mod(std::vector<int> a,std::vector<int> b) {
  return subtract(a,mul(b,divide(a,b)));
}

std::vector<int> differentiate(std::vector<int> a) {
  int n=a.size();
  for (int i=0;i+1<n;++i) a[i]=1LL*(i+1)*a[i+1]%p;
  a[n-1]=0;
  return a;
}

std::vector<int> integrate(std::vector<int> a) {
  for (int i=a.size()-1;i>=1;--i) a[i]=1LL*inv(i)*a[i-1]%p;
  a[0]=0;
  return a;
}

std::vector<int> log(std::vector<int> a) {
  assert(a[0]==1);
  return integrate(mul(differentiate(a),inv(a)));
}

std::vector<int> exp(std::vector<int> a) {
  assert(a[0]==0);
  int n=a.size();
  std::vector<int> b={1};
  for (int len=1;len<n;len*=2) {
    std::vector<int> tmp=subtract(trim(a,2*len),log(trim(b,2*len)));
    ++tmp[0];
    b=trim(mul(b,tmp),2*len);
  }
  return b;
}

std::vector<int> pow(std::vector<int> a,int n) {
  int s=0;
  while (s<(int)a.size() && a[s]==0) ++s;
  if (s==(int)a.size()) return a;
  a=shift(a,-s);
  int b=inv(a[0]);
  for (int i=0;i<(int)a.size();++i) a[i]=1LL*b*a[i]%p;
  a=log(a);
  for (int i=0;i<(int)a.size();++i) a[i]=1LL*n%p*a[i]%p;
  a=exp(a);
  b=pow_mod(inv(b),n%(p-1));
  for (int i=0;i<(int)a.size();++i) a[i]=1LL*b*a[i]%p;
  a=shift(a,s*n);
  return a;
}

void solve() {
  int N, K;
  scanf("%d%d",&N,&K);
  std::vector<int> f(N+1);
  int fac=1;
  for (int i=0;i<=N;++i) {
    f[i]=inv(fac);
    if (i<N) fac=1LL*fac*(i+1)%p;
  }
  f[0]=(1LL*f[0]+p-2)%p;
  f=pow(f,K);
  long long ans=(1LL*pow_mod(K,N)+p-1LL*f[N]*fac%p)%p*inv(2)%p;
  printf("%lld\n",ans);
}

int main() {
  solve();
}
0