結果
| 問題 |
No.2459 Stampaholic (Hard)
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2023-09-02 00:01:31 |
| 言語 | C++14 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
AC
|
| 実行時間 | 2,092 ms / 4,000 ms |
| コード長 | 7,790 bytes |
| コンパイル時間 | 3,025 ms |
| コンパイル使用メモリ | 118,172 KB |
| 実行使用メモリ | 36,992 KB |
| 最終ジャッジ日時 | 2025-01-03 13:40:14 |
| 合計ジャッジ時間 | 25,860 ms |
|
ジャッジサーバーID (参考情報) |
judge4 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 19 |
ソースコード
#include<iostream>
#include<vector>
#include<cassert>
#include<atcoder/modint>
#include<atcoder/convolution>
using namespace std;
using mint=atcoder::modint998244353;
class FPS : public std::vector<mint> {
public:
using std::vector<mint>::vector;
FPS(const std::initializer_list<mint> l) : std::vector<mint>::vector(l) {}
inline FPS& operator=(const std::vector<mint> &&f) & noexcept {
std::vector<mint>::operator=(std::move(f));
return *this;
}
inline FPS& operator=(const std::vector<mint> &f) & {
std::vector<mint>::operator=(f);
return *this;
}
inline const mint operator[](int n) const noexcept {
return n <= deg() ? unsafe_get(n) : 0;
}
inline mint& operator[](int n) noexcept {
ensure_deg(n);
return unsafe_get(n);
}
inline int size() const noexcept { return std::vector<mint>::size(); }
inline int deg() const noexcept { return int(this->size()) - 1; }
inline void cut(int max_deg) noexcept {
if (deg() > max_deg) this->resize(std::max(0, max_deg + 1));
}
inline int normalize() {
while (this->size() and this->back() == 0) this->pop_back();
return deg();
}
inline FPS pre(int max_deg) const noexcept {
return FPS(this->begin(), this->begin() + std::min(this->deg(), std::max(0, max_deg)) + 1);
}
inline FPS operator+() const {
return FPS(*this);
}
FPS operator-() const {
FPS f(*this);
for (auto &e : f) e = mint::mod() - e;
return f;
}
FPS& operator+=(const FPS &g) {
ensure_deg(g.deg());
for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) += g.unsafe_get(i);
return *this;
}
FPS& operator+=(FPS &&g) {
ensure_deg(g.deg());
for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) += g.unsafe_get(i);
return *this;
}
FPS& operator-=(const FPS &g) {
ensure_deg(g.deg());
for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) -= g.unsafe_get(i);
return *this;
}
FPS& operator-=(FPS &&g) {
ensure_deg(g.deg());
for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) -= g.unsafe_get(i);
return *this;
}
inline FPS& operator*=(const FPS &g) {
(*this) = atcoder::convolution(std::move(*this), g);
return *this;
}
inline FPS& operator*=(FPS &&g) {
(*this) = atcoder::convolution(std::move(*this), std::move(g));
return *this;
}
inline FPS& operator*=(const mint x) {
for (auto &e : *this) e *= x;
return *this;
}
inline FPS operator+(FPS &&g) const { return FPS(*this) += std::move(g); }
inline FPS operator-(FPS &&g) const { return FPS(*this) -= std::move(g); }
inline FPS operator*(FPS &&g) const { return FPS(*this) *= std::move(g); }
inline FPS operator+(const FPS &g) const { return FPS(*this) += g; }
inline FPS operator-(const FPS &g) const { return FPS(*this) -= g; }
inline FPS operator*(const FPS &g) const { return FPS(*this) *= g; }
inline FPS operator*(const mint x) const { return FPS(*this) *= x; }
inline friend FPS operator*(const mint x, const FPS &f) { return f * x; }
inline friend FPS operator*(const mint x, FPS &&f) { return f *= x; }
FPS& inv_inplace(const int max_deg) {
FPS res { unsafe_get(0).inv() };
for (int k = 1; k <= max_deg;) {
k *= 2;
int d = 0;
for (const auto &e : this->pre(k) * (res * res)) {
res[d] = res[d] + res[d] - e;
if (++d > k) break;
}
}
res.cut(max_deg);
(*this) = std::move(res);
return *this;
}
private:
inline void ensure_deg(int d) { if (deg() < d) this->resize(d + 1, 0); }
inline const mint& unsafe_get(int i) const { return std::vector<mint>::operator[](i); }
inline mint& unsafe_get(int i) { return std::vector<mint>::operator[](i); }
};
mint fac[5<<17],invfac[5<<17];
vector<mint> bernoulli(int n) {
FPS a(n + 1);
for (int i = 0; i <= n; ++i) a[i] = invfac[i+1];
a.inv_inplace(n);
for (int i = 2; i <= n; ++i) a[i] *= fac[i];
return a;
}
mint comb(int a,int b){return fac[a]*invfac[b]*invfac[a-b];}
mint lagrange_interpolation(const vector<mint>&y,long long x_)
{
int N=y.size();
if(N==0)return mint::raw(0);
if(x_<N)return y[x_];
vector<mint>L(N),R(N);
mint x=x_;
L[0]=mint::raw(1);
for(int i=1;i<N;i++)L[i]=L[i-1]*(x-mint::raw(i-1));
R[N-1]=mint::raw(1);
for(int i=N-1;i--;)R[i]=R[i+1]*(x-mint::raw(i+1));
mint ret=mint::raw(0);
for(int i=0;i<N;i++)
{
mint now=L[i]*R[i]*invfac[i]*invfac[N-i-1]*y[i];
if(N-i&1)ret+=now;
else ret-=now;
}
return ret;
}
pair<pair<int,int>,pair<int,int> >f(int H,int K)
{//min(i,H-i+1,min(K,H-K+1))
pair<pair<int,int>,pair<int,int> >ret;
int T=min(K,H-K+1);
ret.second.first=T;
ret.second.second=0;
{//i=1..(H+1)//2
int UP=(H+1)/2;
if(T<UP)
{
ret.second.second+=UP-T;
}
ret.first.first=min(UP,T);
}
{//i=(H+1)//2+1..H
int UP=(H+1+1)/2-1;
if(T<UP)
{
ret.second.second+=UP-T;
}
ret.first.second=min(UP,T);
}
return ret;
}
int H,W,N,K;
mint invt;
vector<mint>B;
mint coef[5<<17];
mint S(int k,int n)
{//Sum[i^k,{i,1,n}]
mint ret=mint::raw(0);
mint nn=mint::raw(1);
for(int j=k;j>=0;j--)
{
nn*=mint::raw(n);
ret+=comb(k+1,j)*B[j]*nn;
}
ret/=k+1;
//cout<<"Sum[i^"<<k<<",{i,1,"<<n<<"}] = "<<ret.val()<<endl;
return ret;
}
vector<mint>ff(int UP)
{
FPS f(N+1),g(N+1);
mint t=mint::raw(1);
for(int i=0;i<=N;i++)
{
t*=mint::raw(UP);
f[i]=t*invfac[i+1];
g[i]=invfac[i+1];
}
g.inv_inplace(N);
f*=g;
f=f.pre(N);
vector<mint>ret(N+1);
ret[0]=UP-1;
for(int i=0;i<N;i++)ret[i+1]=f[i+1]*fac[i+1];
return ret;
}
mint g1(const vector<mint>&L,const vector<mint>&R)
{
mint ret=mint::raw(0);
//vector<mint>L=ff(h+1),R=ff(w+1);
for(int n=0;n<=N;n++)ret+=coef[n]*L[n]*R[n];
return ret;
}
mint g2(pair<int,int>h,pair<int,int>w)
{
vector<mint>F(N+5);
F[0]=mint::raw(0);
for(int i=1;i<N+5;i++)
{//(1-i*w.first/t)^N
mint now=(1-mint((long)i*w.first)*invt).pow(N);
F[i]=F[i-1]+now;
}
mint ret=mint::raw(0);
ret+=lagrange_interpolation(F,h.first);
ret+=lagrange_interpolation(F,h.second);
return ret*w.second;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
{
const int n=5<<17;
fac[0]=mint::raw(1);
for(int i=1;i<n;i++)fac[i]=fac[i-1]*mint::raw(i);
invfac[n-1]=fac[n-1].inv();
for(int i=n-1;i--;)invfac[i]=invfac[i+1]*mint::raw(i+1);
}
cin>>H>>W>>N>>K;
//B=bernoulli(N);
//B[1]=-B[1];
invt=mint((long)(H-K+1)*(W-K+1)).inv();
{
mint c=mint::raw(1);
for(int i=0;i<=N;i++)
{
coef[i]=c*comb(N,i);
c*=-invt;
}
}
pair<pair<int,int>,pair<int,int> >h=f(H,K),w=f(W,K);
mint ans=(long)H*W;
vector<mint>L1=ff(h.first.first+1),L2=ff(h.first.second+1);
vector<mint>R1=ff(w.first.first+1),R2=ff(w.first.second+1);
ans-=g1(L1,R1);
ans-=g1(L1,R2);
ans-=g1(L2,R1);
ans-=g1(L2,R2);
/*
ans-=g1(h.first.first,w.first.first);
ans-=g1(h.first.first,w.first.second);
ans-=g1(h.first.second,w.first.first);
ans-=g1(h.first.second,w.first.second);
*/
ans-=g2(h.first,w.second);
ans-=g2(w.first,h.second);
if(h.second.second>0&&w.second.second>0)
{
ans-=(1-mint((long)h.second.first*w.second.first)*invt).pow(N)*h.second.second*w.second.second;
}
cout<<ans.val()<<endl;
}