結果
問題 | No.2459 Stampaholic (Hard) |
ユーザー | kotatsugame |
提出日時 | 2023-09-02 00:37:42 |
言語 | C++14 (gcc 12.3.0 + boost 1.83.0) |
結果 |
AC
|
実行時間 | 1,003 ms / 4,000 ms |
コード長 | 8,287 bytes |
コンパイル時間 | 2,437 ms |
コンパイル使用メモリ | 124,080 KB |
実行使用メモリ | 33,040 KB |
最終ジャッジ日時 | 2024-06-11 07:02:55 |
合計ジャッジ時間 | 14,812 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge5 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 11 ms
11,008 KB |
testcase_01 | AC | 989 ms
32,916 KB |
testcase_02 | AC | 228 ms
15,796 KB |
testcase_03 | AC | 11 ms
11,008 KB |
testcase_04 | AC | 12 ms
11,136 KB |
testcase_05 | AC | 11 ms
11,136 KB |
testcase_06 | AC | 11 ms
11,136 KB |
testcase_07 | AC | 11 ms
11,136 KB |
testcase_08 | AC | 461 ms
20,480 KB |
testcase_09 | AC | 234 ms
16,140 KB |
testcase_10 | AC | 976 ms
31,352 KB |
testcase_11 | AC | 485 ms
21,744 KB |
testcase_12 | AC | 999 ms
32,836 KB |
testcase_13 | AC | 966 ms
31,332 KB |
testcase_14 | AC | 238 ms
16,364 KB |
testcase_15 | AC | 994 ms
33,040 KB |
testcase_16 | AC | 1,001 ms
33,040 KB |
testcase_17 | AC | 974 ms
33,040 KB |
testcase_18 | AC | 981 ms
33,040 KB |
testcase_19 | AC | 1,003 ms
33,040 KB |
testcase_20 | AC | 12 ms
11,136 KB |
testcase_21 | AC | 926 ms
30,140 KB |
ソースコード
#include<iostream> #include<vector> #include<cassert> #include<atcoder/modint> #include<atcoder/convolution> using namespace std; using mint=atcoder::modint998244353; class FPS : public std::vector<mint> { public: using std::vector<mint>::vector; FPS(const std::initializer_list<mint> l) : std::vector<mint>::vector(l) {} inline FPS& operator=(const std::vector<mint> &&f) & noexcept { std::vector<mint>::operator=(std::move(f)); return *this; } inline FPS& operator=(const std::vector<mint> &f) & { std::vector<mint>::operator=(f); return *this; } inline const mint operator[](int n) const noexcept { return n <= deg() ? unsafe_get(n) : 0; } inline mint& operator[](int n) noexcept { ensure_deg(n); return unsafe_get(n); } inline int size() const noexcept { return std::vector<mint>::size(); } inline int deg() const noexcept { return int(this->size()) - 1; } inline void cut(int max_deg) noexcept { if (deg() > max_deg) this->resize(std::max(0, max_deg + 1)); } inline int normalize() { while (this->size() and this->back() == 0) this->pop_back(); return deg(); } inline FPS pre(int max_deg) const noexcept { return FPS(this->begin(), this->begin() + std::min(this->deg(), std::max(0, max_deg)) + 1); } inline FPS operator+() const { return FPS(*this); } FPS operator-() const { FPS f(*this); for (auto &e : f) e = mint::mod() - e; return f; } FPS& operator+=(const FPS &g) { ensure_deg(g.deg()); for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) += g.unsafe_get(i); return *this; } FPS& operator+=(FPS &&g) { ensure_deg(g.deg()); for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) += g.unsafe_get(i); return *this; } FPS& operator-=(const FPS &g) { ensure_deg(g.deg()); for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) -= g.unsafe_get(i); return *this; } FPS& operator-=(FPS &&g) { ensure_deg(g.deg()); for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) -= g.unsafe_get(i); return *this; } inline FPS& operator*=(const FPS &g) { (*this) = atcoder::convolution(std::move(*this), g); return *this; } inline FPS& operator*=(FPS &&g) { (*this) = atcoder::convolution(std::move(*this), std::move(g)); return *this; } inline FPS& operator*=(const mint x) { for (auto &e : *this) e *= x; return *this; } inline FPS operator+(FPS &&g) const { return FPS(*this) += std::move(g); } inline FPS operator-(FPS &&g) const { return FPS(*this) -= std::move(g); } inline FPS operator*(FPS &&g) const { return FPS(*this) *= std::move(g); } inline FPS operator+(const FPS &g) const { return FPS(*this) += g; } inline FPS operator-(const FPS &g) const { return FPS(*this) -= g; } inline FPS operator*(const FPS &g) const { return FPS(*this) *= g; } inline FPS operator*(const mint x) const { return FPS(*this) *= x; } inline friend FPS operator*(const mint x, const FPS &f) { return f * x; } inline friend FPS operator*(const mint x, FPS &&f) { return f *= x; } FPS& inv_inplace(const int max_deg) { FPS res { unsafe_get(0).inv() }; for (int k = 1; k <= max_deg;) { k *= 2; int d = 0; for (const auto &e : this->pre(k) * (res * res)) { res[d] = res[d] + res[d] - e; if (++d > k) break; } } res.cut(max_deg); (*this) = std::move(res); return *this; } private: inline void ensure_deg(int d) { if (deg() < d) this->resize(d + 1, 0); } inline const mint& unsafe_get(int i) const { return std::vector<mint>::operator[](i); } inline mint& unsafe_get(int i) { return std::vector<mint>::operator[](i); } }; mint fac[5<<17],invfac[5<<17]; vector<mint> bernoulli(int n) { FPS a(n + 1); for (int i = 0; i <= n; ++i) a[i] = invfac[i+1]; a.inv_inplace(n); for (int i = 2; i <= n; ++i) a[i] *= fac[i]; return a; } mint comb(int a,int b){return fac[a]*invfac[b]*invfac[a-b];} mint lagrange_interpolation(const vector<mint>&y,long long x_) { int N=y.size(); if(N==0)return mint::raw(0); if(x_<N)return y[x_]; vector<mint>L(N),R(N); mint x=x_; L[0]=mint::raw(1); for(int i=1;i<N;i++)L[i]=L[i-1]*(x-mint::raw(i-1)); R[N-1]=mint::raw(1); for(int i=N-1;i--;)R[i]=R[i+1]*(x-mint::raw(i+1)); mint ret=mint::raw(0); for(int i=0;i<N;i++) { mint now=L[i]*R[i]*invfac[i]*invfac[N-i-1]*y[i]; if(N-i&1)ret+=now; else ret-=now; } return ret; } pair<pair<int,int>,pair<int,int> >f(int H,int K) {//min(i,H-i+1,min(K,H-K+1)) pair<pair<int,int>,pair<int,int> >ret; int T=min(K,H-K+1); ret.second.first=T; ret.second.second=0; {//i=1..(H+1)//2 int UP=(H+1)/2; if(T<UP) { ret.second.second+=UP-T; } ret.first.first=min(UP,T); } {//i=(H+1)//2+1..H int UP=(H+1+1)/2-1; if(T<UP) { ret.second.second+=UP-T; } ret.first.second=min(UP,T); } return ret; } int H,W,N,K; mint invt; vector<mint>B; mint coef[5<<17]; mint S(int k,int n) {//Sum[i^k,{i,1,n}] mint ret=mint::raw(0); mint nn=mint::raw(1); for(int j=k;j>=0;j--) { nn*=mint::raw(n); ret+=comb(k+1,j)*B[j]*nn; } ret/=k+1; //cout<<"Sum[i^"<<k<<",{i,1,"<<n<<"}] = "<<ret.val()<<endl; return ret; } vector<mint>ff(int UP) { FPS f(N+1),g(N+1); mint t=mint::raw(1); for(int i=0;i<=N;i++) { t*=mint::raw(UP); f[i]=t*invfac[i+1]; g[i]=invfac[i+1]; } g.inv_inplace(N); f*=g; f=f.pre(N); vector<mint>ret(N+1); ret[0]=UP-1; for(int i=0;i<N;i++)ret[i+1]=f[i+1]*fac[i+1]; return ret; } mint g1(const vector<mint>&L,const vector<mint>&R) { mint ret=mint::raw(0); //vector<mint>L=ff(h+1),R=ff(w+1); for(int n=0;n<=N;n++)ret+=coef[n]*L[n]*R[n]; return ret; } mint g1(int h,int w) { vector<mint>L(N+2),R(N+2); L[0]=R[0]=mint::raw(1); for(int i=1;i<=N+1;i++) { L[i]=L[i-1]*mint::raw(h); R[i]=R[i-1]*mint::raw(w); } for(int i=0;i<=N+1;i++) { L[i]*=invfac[i]; R[i]*=invfac[i]; } L=atcoder::convolution(L,B); R=atcoder::convolution(R,B); mint ret=mint::raw(0); for(int n=0;n<=N;n++) { mint l=L[n+1],r=R[n+1]; l-=B[n+1]; r-=B[n+1]; l*=fac[n]; r*=fac[n]; ret+=coef[n]*l*r; } return ret; } mint g2(pair<int,int>h,pair<int,int>w) { vector<mint>F(N+5); F[0]=mint::raw(0); for(int i=1;i<N+5;i++) {//(1-i*w.first/t)^N mint now=(1-mint((long)i*w.first)*invt).pow(N); F[i]=F[i-1]+now; } mint ret=mint::raw(0); ret+=lagrange_interpolation(F,h.first); ret+=lagrange_interpolation(F,h.second); return ret*w.second; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); { const int n=5<<17; fac[0]=mint::raw(1); for(int i=1;i<n;i++)fac[i]=fac[i-1]*mint::raw(i); invfac[n-1]=fac[n-1].inv(); for(int i=n-1;i--;)invfac[i]=invfac[i+1]*mint::raw(i+1); } cin>>H>>W>>N>>K; B=bernoulli(N+1); B[1]=-B[1]; for(int i=0;i<=N+1;i++)B[i]*=invfac[i]; invt=mint((long)(H-K+1)*(W-K+1)).inv(); { mint c=mint::raw(1); for(int i=0;i<=N;i++) { coef[i]=c*comb(N,i); c*=-invt; } } pair<pair<int,int>,pair<int,int> >h=f(H,K),w=f(W,K); mint ans=(long)H*W; /* vector<mint>L1=ff(h.first.first+1),L2=ff(h.first.second+1); vector<mint>R1=ff(w.first.first+1),R2=ff(w.first.second+1); ans-=g1(L1,R1); ans-=g1(L1,R2); ans-=g1(L2,R1); ans-=g1(L2,R2); */ ans-=g1(h.first.first,w.first.first); ans-=g1(h.first.first,w.first.second); ans-=g1(h.first.second,w.first.first); ans-=g1(h.first.second,w.first.second); ans-=g2(h.first,w.second); ans-=g2(w.first,h.second); if(h.second.second>0&&w.second.second>0) { ans-=(1-mint((long)h.second.first*w.second.first)*invt).pow(N)*h.second.second*w.second.second; } cout<<ans.val()<<endl; }