結果

問題 No.215 素数サイコロと合成数サイコロ (3-Hard)
ユーザー piyoko_212piyoko_212
提出日時 2015-12-27 00:20:42
言語 C++11
(gcc 11.4.0)
結果
MLE  
実行時間 -
コード長 5,613 bytes
コンパイル時間 1,031 ms
コンパイル使用メモリ 60,864 KB
実行使用メモリ 82,856 KB
最終ジャッジ日時 2024-09-19 07:13:25
合計ジャッジ時間 11,931 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 MLE -
testcase_01 MLE -
権限があれば一括ダウンロードができます
コンパイルメッセージ
main.cpp: In function ‘int main()’:
main.cpp:158:14: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  158 |         scanf("%lld%d%d",&n,&a,&b);
      |         ~~~~~^~~~~~~~~~~~~~~~~~~~~

ソースコード

diff #

#include<stdio.h>
#include<algorithm>
#include<vector>
using namespace std;
namespace NTT{
	long long extgcd(long long a,long long b,long long&x,long long&y){
		for(long long u=y=1,v=x=0;a;){
			long long q=b/a;swap(x-=q*u,u);swap(y-=q*v,v);swap(b-=q*a,a);
		}
		return b;
	}
	long long mod_inv(long long a,long long m){
		long long x,y;
		extgcd(a,m,x,y);
		return (m+x%m)%m;
	}
	long long mod_pow(long long a,long long b,long long m){
		long long ret=1;
		while(b){
			if(b%2)ret=ret*a%m;
			b/=2;
			a=a*a%m;
		}
		return ret;
	}
	template<int mod,int primitive_root> class NTT{
		public:
		int get_mod()const{return mod;}
		void _ntt(vector<long long>&a,int sign){
			const int n=a.size();
			const int g=3;
			int h=(int)mod_pow(g,(mod-1)/n,mod);
			if(sign==-1)h=(int)mod_inv(h,mod);
			int i=0;
			for(int j=1;j<n-1;j++){
				for(int k=n>>1;k>(i^=k);k>>=1);
				if(j<i)swap(a[i],a[j]);
			}
			for(int m=1;m<n;m*=2){
				const int m2=2*m;
				const long long base=mod_pow(h,n/m2,mod);
				long long w=1;
				for(int x=0;x<m;x++){
					for(int s=x;s<n;s+=m2){
						long long u=a[s];
						long long d=a[s+m]*w%mod;
						a[s]=u+d;
						if(a[s]>=mod)a[s]-=mod;
						a[s+m]=u-d;
						if(a[s+m]<0)a[s+m]+=mod;
					}
					w=w*base%mod;
				}
			}
			for(int i=0;i<a.size();i++)if(a[i]<0)a[i]+=mod;
		}
		void ntt(vector<long long>&input){
			_ntt(input,1);
		}
		void intt(vector<long long>&input){
			_ntt(input,-1);
			const int n_inv=mod_inv(input.size(),mod);
			for(int i=0;i<input.size();i++)input[i]=input[i]*n_inv%mod;
		}
		vector<long long>convolution(const vector<long long>&a,const vector<long long>&b){
			int ntt_size=1;
			while(ntt_size<a.size()+b.size())ntt_size*=2;
			vector<long long>_a=a,_b=b;
			_a.resize(ntt_size);
			_b.resize(ntt_size);
			ntt(_a);ntt(_b);
			for(int i=0;i<ntt_size;i++){
				(_a[i]*=_b[i])%=mod;
			}
			intt(_a);
			return _a;
		}
	};
	long long garner(vector<pair<long long,long long> > mr,long long mod){
		mr.push_back(make_pair(mod,0));
		vector<long long>coffs(mr.size(),1);
		vector<long long>constants(mr.size(),0);
		for(int i=0;i<mr.size()-1;i++){
			long long v=(mr[i].second-constants[i])*mod_inv(coffs[i],mr[i].first)%mr[i].first;
			if(v<0)v+=mr[i].first;
			for(int j=i+1;j<mr.size();j++){
				(constants[j]+=coffs[j]*v)%=mr[j].first;
				(coffs[j]*=mr[i].first)%=mr[j].first;
			}
		}
		return constants[mr.size()-1];
	}
	typedef NTT<167772161,3> NTT_1;
	typedef NTT<469762049,3> NTT_2;
	typedef NTT<1224736769,3> NTT_3;
	vector<long long>int32mod_convolution(vector<long long>a,vector<long long>b,int mod){
		for(int i=0;i<a.size();i++)a[i]%=mod;
		for(int i=0;i<b.size();i++)b[i]%=mod;
		NTT_1 ntt1;
		NTT_2 ntt2;
		NTT_3 ntt3;
		vector<long long>x=ntt1.convolution(a,b);
		vector<long long>y=ntt2.convolution(a,b);
		vector<long long>z=ntt3.convolution(a,b);
		vector<long long>ret(x.size());
		vector<pair<long long,long long> >mr(3);
		for(int i=0;i<x.size();i++){
			mr[0].first=ntt1.get_mod(),mr[0].second=x[i];
			mr[1].first=ntt2.get_mod(),mr[1].second=y[i];
			mr[2].first=ntt3.get_mod(),mr[2].second=z[i];
			ret[i]=garner(mr,mod);
		}
		return ret;
	}
}
long long mod=1000000007;
int s[]={2,3,5,7,11,13};
int t[]={4,6,8,9,10,12};
long long U[8000];
long long S[4000];
long long T[4000];
long long dp[8][310][4000];
long long ul=8000;
vector<long long>g;
vector<long long>h;
vector<long long> calc(long long t){
//	printf("%lld\n",t);
	if(t<ul){
		vector<long long>ret(8000);
		ret[t]=1;
		return ret;
	}
	vector<long long>chi=calc(t/2);
	vector<long long>na(8000);
	vector<long long>nb(8001);
	for(int i=0;i<8000;i++){
		na[i]=nb[i+t%2]=chi[i];
	}
	vector<long long>f=NTT::int32mod_convolution(na,nb,mod);
	f.resize(16010);
	vector<long long>_f=f;
	reverse(_f.begin(),_f.end());
	vector<long long>q=NTT::int32mod_convolution(_f,h,mod);
	q.resize(8010);
	reverse(q.begin(),q.end());
	vector<long long>r=NTT::int32mod_convolution(q,g,mod);
//	printf("%d %d\n",r.size(),f.size());
	for(int i=0;i<r.size();i++)r[i]=(mod-r[i])%mod;
	for(int i=0;i<f.size();i++)r[i]=(r[i]+f[i])%mod;
	r.resize(8000);
//	printf("%lld: \n",t);
//	for(int i=0;i<r.size();i++)if(r[i])printf("%d: %lld\n",i,r[i]);
	return r;
}
int main(){
	long long n;
	int a,b;
	scanf("%lld%d%d",&n,&a,&b);
	dp[0][0][0]=1;
	for(int i=0;i<6;i++){
		for(int j=0;j<=a;j++){
			for(int k=0;k<3950;k++){
				if(dp[i][j][k]==0)continue;
				dp[i+1][j][k]=(dp[i+1][j][k]+dp[i][j][k])%mod;
				if(j<a)dp[i][j+1][k+s[i]]=(dp[i][j+1][k+s[i]]+dp[i][j][k])%mod;
			}
		}
	}
	for(int i=0;i<4000;i++)S[i]=dp[6][a][i];
	for(int i=0;i<8;i++)for(int j=0;j<310;j++)for(int k=0;k<4000;k++)dp[i][j][k]=0;
	dp[0][0][0]=1;
	for(int i=0;i<6;i++){
		for(int j=0;j<=b;j++){
			for(int k=0;k<3950;k++){
				if(dp[i][j][k]==0)continue;
				dp[i+1][j][k]=(dp[i+1][j][k]+dp[i][j][k])%mod;
				if(j<b)dp[i][j+1][k+t[i]]=(dp[i][j+1][k+t[i]]+dp[i][j][k])%mod;
			}
		}
	}
	for(int i=0;i<4000;i++)T[i]=dp[6][b][i];
	for(int i=0;i<4000;i++)for(int j=0;j<4000;j++){
		U[i+j]=(U[i+j]+S[i]*T[j])%mod;
	}
	g.resize(8001);
	g[8000]=1;
	for(int i=1;i<8000;i++){
		g[8000-i]=(mod-U[i])%mod;
	}
	vector<long long>_g=g;
	reverse(_g.begin(),_g.end());
	h.push_back(1);
	int L=1;
	while(L<8010){
		vector<long long>h2=NTT::int32mod_convolution(h,h,mod);
		vector<long long>tg=_g;
		tg.resize(L*2);
		h2=NTT::int32mod_convolution(h2,tg,mod);
		h2.resize(L*2);
		for(int i=0;i<L*2;i++)h2[i]=(mod-h2[i])%mod;
		for(int i=0;i<L;i++)h2[i]=(h2[i]+h[i]*2)%mod;
		h=h2;
		L*=2;
	}
	h.resize(8010);
	vector<long long> ans=calc(n+7999);
	long long ret=0;
	for(int i=0;i<8000;i++){
		ret=(ret+ans[i])%mod;
	}
	printf("%lld\n",ret);
}
0