結果

問題 No.2459 Stampaholic (Hard)
ユーザー kotatsugamekotatsugame
提出日時 2023-09-01 23:57:44
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
TLE  
実行時間 -
コード長 7,769 bytes
コンパイル時間 2,874 ms
コンパイル使用メモリ 123,364 KB
実行使用メモリ 39,248 KB
最終ジャッジ日時 2024-06-11 06:18:10
合計ジャッジ時間 49,515 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 10 ms
11,204 KB
testcase_01 AC 3,855 ms
38,876 KB
testcase_02 AC 902 ms
16,792 KB
testcase_03 AC 10 ms
11,120 KB
testcase_04 AC 10 ms
11,208 KB
testcase_05 AC 10 ms
11,072 KB
testcase_06 AC 10 ms
11,192 KB
testcase_07 AC 10 ms
11,204 KB
testcase_08 AC 1,869 ms
22,140 KB
testcase_09 AC 896 ms
17,392 KB
testcase_10 AC 3,829 ms
36,028 KB
testcase_11 AC 1,891 ms
24,324 KB
testcase_12 AC 3,889 ms
38,332 KB
testcase_13 AC 3,796 ms
35,884 KB
testcase_14 AC 913 ms
17,992 KB
testcase_15 AC 3,889 ms
38,872 KB
testcase_16 TLE -
testcase_17 TLE -
testcase_18 TLE -
testcase_19 AC 3,966 ms
38,868 KB
testcase_20 AC 10 ms
11,192 KB
testcase_21 AC 3,790 ms
33,664 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<iostream>
#include<vector>
#include<cassert>
#include<atcoder/modint>
#include<atcoder/convolution>
using namespace std;
using mint=atcoder::modint998244353;
class FPS : public std::vector<mint> {
    public:
        using std::vector<mint>::vector;

        FPS(const std::initializer_list<mint> l) : std::vector<mint>::vector(l) {}

        inline FPS& operator=(const std::vector<mint> &&f) & noexcept {
            std::vector<mint>::operator=(std::move(f));
            return *this;
        }
        inline FPS& operator=(const std::vector<mint>  &f) & {
            std::vector<mint>::operator=(f);
            return *this;
        }

        inline const mint operator[](int n) const noexcept {
            return n <= deg() ? unsafe_get(n) : 0;
        }
        inline mint& operator[](int n) noexcept {
            ensure_deg(n);
            return unsafe_get(n);
        }

        inline int size() const noexcept { return std::vector<mint>::size(); }
        inline int deg()  const noexcept { return int(this->size()) - 1; }
        inline void cut(int max_deg) noexcept {
            if (deg() > max_deg) this->resize(std::max(0, max_deg + 1));
        }
        inline int normalize() {
            while (this->size() and this->back() == 0) this->pop_back();
            return deg();
        }
        inline FPS pre(int max_deg) const noexcept {
            return FPS(this->begin(), this->begin() + std::min(this->deg(), std::max(0, max_deg)) + 1);
        }

        inline FPS operator+() const {
            return FPS(*this);
        }
        FPS operator-() const {
            FPS f(*this);
            for (auto &e : f) e = mint::mod() - e;
            return f;
        }
        FPS& operator+=(const FPS &g) {
            ensure_deg(g.deg());
            for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) += g.unsafe_get(i);
            return *this;
        }
        FPS& operator+=(FPS &&g) {
            ensure_deg(g.deg());
            for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) += g.unsafe_get(i);
            return *this;
        }
        FPS& operator-=(const FPS &g) {
            ensure_deg(g.deg());
            for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) -= g.unsafe_get(i);
            return *this;
        }
        FPS& operator-=(FPS &&g) {
            ensure_deg(g.deg());
            for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) -= g.unsafe_get(i);
            return *this;
        }
        inline FPS& operator*=(const FPS &g) {
            (*this) = atcoder::convolution(std::move(*this), g);
            return *this;
        }
        inline FPS& operator*=(FPS &&g) {
            (*this) = atcoder::convolution(std::move(*this), std::move(g));
            return *this;
        }
        inline FPS& operator*=(const mint x) {
            for (auto &e : *this) e *= x;
            return *this;
        }

        inline FPS operator+(FPS &&g) const { return FPS(*this) += std::move(g); }
        inline FPS operator-(FPS &&g) const { return FPS(*this) -= std::move(g); }
        inline FPS operator*(FPS &&g) const { return FPS(*this) *= std::move(g); }
        inline FPS operator+(const FPS &g) const { return FPS(*this) += g; }
        inline FPS operator-(const FPS &g) const { return FPS(*this) -= g; }
        inline FPS operator*(const FPS &g) const { return FPS(*this) *= g; }
        inline FPS operator*(const mint x) const { return FPS(*this) *= x; }
        inline friend FPS operator*(const mint x, const FPS  &f) { return f * x; }
        inline friend FPS operator*(const mint x,       FPS &&f) { return f *= x; }

        FPS& inv_inplace(const int max_deg) {
            FPS res { unsafe_get(0).inv() };
            for (int k = 1; k <= max_deg;) {
                k *= 2;
                int d = 0;
                for (const auto &e : this->pre(k) * (res * res)) {
                    res[d] = res[d] + res[d] - e;
                    if (++d > k) break;
                }
            }
            res.cut(max_deg);
            (*this) = std::move(res);
            return *this;
        }

    private:
        inline void ensure_deg(int d) { if (deg() < d) this->resize(d + 1, 0); }
        inline const mint& unsafe_get(int i) const { return std::vector<mint>::operator[](i); }
        inline       mint& unsafe_get(int i)       { return std::vector<mint>::operator[](i); }
};
mint fac[5<<17],invfac[5<<17];
vector<mint> bernoulli(int n) {
    FPS a(n + 1);
    for (int i = 0; i <= n; ++i) a[i] = invfac[i+1];
    a.inv_inplace(n);
    for (int i = 2; i <= n; ++i) a[i] *= fac[i];
    return a;
}
mint comb(int a,int b){return fac[a]*invfac[b]*invfac[a-b];}
mint lagrange_interpolation(const vector<mint>&y,long long x_)
{
	int N=y.size();
	if(N==0)return mint::raw(0);
	if(x_<N)return y[x_];
	vector<mint>L(N),R(N);
	mint x=x_;
	L[0]=mint::raw(1);
	for(int i=1;i<N;i++)L[i]=L[i-1]*(x-mint::raw(i-1));
	R[N-1]=mint::raw(1);
	for(int i=N-1;i--;)R[i]=R[i+1]*(x-mint::raw(i+1));
	mint ret=mint::raw(0);
	for(int i=0;i<N;i++)
	{
		mint now=L[i]*R[i]*invfac[i]*invfac[N-i-1]*y[i];
		if(N-i&1)ret+=now;
		else ret-=now;
	}
	return ret;
}
pair<pair<int,int>,pair<int,int> >f(int H,int K)
{//min(i,H-i+1,min(K,H-K+1))
	pair<pair<int,int>,pair<int,int> >ret;
	int T=min(K,H-K+1);
	ret.second.first=T;
	ret.second.second=0;
	{//i=1..(H+1)//2
		int UP=(H+1)/2;
		if(T<UP)
		{
			ret.second.second+=UP-T;
		}
		ret.first.first=min(UP,T);
	}
	{//i=(H+1)//2+1..H
		int UP=(H+1+1)/2-1;
		if(T<UP)
		{
			ret.second.second+=UP-T;
		}
		ret.first.second=min(UP,T);
	}
	return ret;
}
int H,W,N,K;
mint invt;
vector<mint>B;
mint coef[5<<17];
mint S(int k,int n)
{//Sum[i^k,{i,1,n}]
	mint ret=mint::raw(0);
	mint nn=mint::raw(1);
	for(int j=k;j>=0;j--)
	{
		nn*=mint::raw(n);
		ret+=comb(k+1,j)*B[j]*nn;
	}
	ret/=k+1;
	//cout<<"Sum[i^"<<k<<",{i,1,"<<n<<"}] = "<<ret.val()<<endl;
	return ret;
}
vector<mint>ff(int UP)
{
	FPS f(N+1),g(N+1);
	mint t=mint::raw(1);
	for(int i=0;i<=N;i++)
	{
		t*=mint::raw(UP);
		f[i]=t*invfac[i+1];
		g[i]=invfac[i+1];
	}
	g.inv_inplace(N);
	f*=g;
	f=f.pre(N);
	vector<mint>ret(N+1);
	ret[0]=UP-1;
	for(int i=0;i<N;i++)ret[i+1]=f[i+1]*fac[i+1];
	return ret;
}
mint g1(int h,int w)
{
	mint ret=mint::raw(0);
	vector<mint>L=ff(h+1),R=ff(w+1);
	for(int n=0;n<=N;n++)
	{
		ret+=coef[n]*L[n]*R[n];
		//S(n,h)*S(n,w);
		//for(int i=1;i<=h;i++)for(int j=1;j<=w;j++)ret+=coef[n]*mint::raw(i*j).pow(n);
	}
	/*
	mint v=0;
	for(int i=1;i<=h;i++)for(int j=1;j<=w;j++)v+=(1-mint::raw((long)i*j)*invt).pow(N);
	*/
	return ret;
}
mint g2(pair<int,int>h,pair<int,int>w)
{
	vector<mint>F(N+5);
	F[0]=mint::raw(0);
	for(int i=1;i<N+5;i++)
	{//(1-i*w.first/t)^N
		mint now=(1-mint((long)i*w.first)*invt).pow(N);
		F[i]=F[i-1]+now;
	}
	mint ret=mint::raw(0);
	ret+=lagrange_interpolation(F,h.first);
	ret+=lagrange_interpolation(F,h.second);
	return ret*w.second;
}
int main()
{
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	{
		const int n=5<<17;
		fac[0]=mint::raw(1);
		for(int i=1;i<n;i++)fac[i]=fac[i-1]*mint::raw(i);
		invfac[n-1]=fac[n-1].inv();
		for(int i=n-1;i--;)invfac[i]=invfac[i+1]*mint::raw(i+1);
	}
	cin>>H>>W>>N>>K;
	B=bernoulli(N);
	B[1]=-B[1];
	invt=mint((long)(H-K+1)*(W-K+1)).inv();
	{
		mint c=mint::raw(1);
		for(int i=0;i<=N;i++)
		{
			coef[i]=c*comb(N,i);
			c*=-invt;
		}
	}
	pair<pair<int,int>,pair<int,int> >h=f(H,K),w=f(W,K);
	mint ans=(long)H*W;
	ans-=g1(h.first.first,w.first.first);
	ans-=g1(h.first.first,w.first.second);
	ans-=g1(h.first.second,w.first.first);
	ans-=g1(h.first.second,w.first.second);
	ans-=g2(h.first,w.second);
	ans-=g2(w.first,h.second);
	if(h.second.second>0&&w.second.second>0)
	{
		ans-=(1-mint((long)h.second.first*w.second.first)*invt).pow(N)*h.second.second*w.second.second;
	}
	cout<<ans.val()<<endl;
}
0