結果

問題 No.502 階乗を計算するだけ
ユーザー snrnsidysnrnsidy
提出日時 2021-02-28 19:38:39
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 94 ms / 1,000 ms
コード長 4,488 bytes
コンパイル時間 2,049 ms
コンパイル使用メモリ 208,168 KB
実行使用メモリ 15,208 KB
最終ジャッジ日時 2024-04-12 10:13:26
合計ジャッジ時間 4,100 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
6,812 KB
testcase_01 AC 2 ms
6,940 KB
testcase_02 AC 1 ms
6,944 KB
testcase_03 AC 2 ms
6,940 KB
testcase_04 AC 2 ms
6,940 KB
testcase_05 AC 2 ms
6,940 KB
testcase_06 AC 1 ms
6,940 KB
testcase_07 AC 1 ms
6,944 KB
testcase_08 AC 1 ms
6,944 KB
testcase_09 AC 2 ms
6,944 KB
testcase_10 AC 2 ms
6,944 KB
testcase_11 AC 2 ms
6,944 KB
testcase_12 AC 1 ms
6,940 KB
testcase_13 AC 2 ms
6,940 KB
testcase_14 AC 1 ms
6,940 KB
testcase_15 AC 1 ms
6,944 KB
testcase_16 AC 2 ms
6,940 KB
testcase_17 AC 2 ms
6,944 KB
testcase_18 AC 2 ms
6,944 KB
testcase_19 AC 1 ms
6,944 KB
testcase_20 AC 2 ms
6,940 KB
testcase_21 AC 2 ms
6,940 KB
testcase_22 AC 4 ms
6,940 KB
testcase_23 AC 3 ms
6,944 KB
testcase_24 AC 4 ms
6,944 KB
testcase_25 AC 2 ms
6,944 KB
testcase_26 AC 4 ms
6,940 KB
testcase_27 AC 3 ms
6,940 KB
testcase_28 AC 4 ms
6,944 KB
testcase_29 AC 2 ms
6,944 KB
testcase_30 AC 4 ms
6,940 KB
testcase_31 AC 4 ms
6,944 KB
testcase_32 AC 92 ms
15,084 KB
testcase_33 AC 93 ms
15,136 KB
testcase_34 AC 92 ms
15,080 KB
testcase_35 AC 45 ms
9,316 KB
testcase_36 AC 92 ms
15,084 KB
testcase_37 AC 92 ms
15,204 KB
testcase_38 AC 94 ms
15,204 KB
testcase_39 AC 94 ms
15,208 KB
testcase_40 AC 45 ms
9,196 KB
testcase_41 AC 92 ms
15,108 KB
testcase_42 AC 2 ms
6,940 KB
testcase_43 AC 2 ms
6,940 KB
testcase_44 AC 2 ms
6,940 KB
testcase_45 AC 2 ms
6,940 KB
testcase_46 AC 2 ms
6,940 KB
testcase_47 AC 2 ms
6,944 KB
testcase_48 AC 1 ms
6,940 KB
testcase_49 AC 1 ms
6,940 KB
testcase_50 AC 1 ms
6,940 KB
testcase_51 AC 1 ms
6,944 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<bits/stdc++.h>
using namespace std;

using i64 = int64_t;
using u64 = uint64_t;
using base = complex<double>;

/**
 * In-place iterative radix-2 Fourier transform.
 *
 * The forward transform uses the twiddle factors exp(+2*pi*i*k/n); the inverse
 * uses the conjugate factors and divides every entry by n afterwards.
 *
 * @param a   data to transform; its size is expected to be a power of two
 * @param inv false for the forward transform, true for the inverse
 */
void fft(vector<base> &a, bool inv){
	const int n = a.size();

	// Bit-reversal permutation: `rev` always holds the bit-reversed index of `pos`.
	int rev = 0;
	for(int pos = 1; pos < n; ++pos){
		int mask = n >> 1;
		for(; rev >= mask; mask >>= 1) rev -= mask;
		rev += mask;
		if(pos < rev) swap(a[pos], a[rev]);
	}

	// Precompute the first n/2 powers of the primitive n-th root of unity.
	const double theta = 2 * acos(-1) / n * (inv ? -1 : 1);
	const int halfN = n / 2;
	vector<base> twiddle(halfN);
	for(int t = 0; t < halfN; ++t)
		twiddle[t] = base(cos(theta * t), sin(theta * t));

	// Butterfly passes over blocks of length 2, 4, ..., n.
	for(int len = 2; len <= n; len <<= 1){
		const int half = len / 2;
		const int stride = n / len;	// index step inside the twiddle table for this block length
		for(int start = 0; start < n; start += len){
			for(int k = 0; k < half; ++k){
				const base lo = a[start + k];
				const base hi = a[start + k + half] * twiddle[stride * k];
				a[start + k] = lo + hi;
				a[start + k + half] = lo - hi;
			}
		}
	}

	if(!inv) return;
	for(base &x : a) x /= n;
}

vector<i64> multiply(vector<i64> &v, vector<i64> &w, i64 mod){
	int n = 2; while(n < v.size() + w.size()) n <<= 1;
	vector<base> v1(n), v2(n), r1(n), r2(n);
	for(int i=0; i<v.size(); i++)
		v1[i] = base(v[i] >> 15, v[i] & 32767);
	for(int i=0; i<w.size(); i++)
		v2[i] = base(w[i] >> 15, w[i] & 32767);
	fft(v1, 0);
	fft(v2, 0);
	for(int i=0; i<n; i++){
		int j = (i ? (n - i) : i);
		base ans1 = (v1[i] + conj(v1[j])) * base(0.5, 0);
		base ans2 = (v1[i] - conj(v1[j])) * base(0, -0.5);
		base ans3 = (v2[i] + conj(v2[j])) * base(0.5, 0);
		base ans4 = (v2[i] - conj(v2[j])) * base(0, -0.5);
		r1[i] = (ans1 * ans3) + (ans1 * ans4) * base(0, 1);
		r2[i] = (ans2 * ans3) + (ans2 * ans4) * base(0, 1);
	}
	fft(r1, 1);
	fft(r2, 1);
	vector<i64> ret(n);
	for(int i=0; i<n; i++){
		i64 av = (i64)round(r1[i].real());
		i64 bv = (i64)round(r1[i].imag()) + (i64)round(r2[i].real());
		i64 cv = (i64)round(r2[i].imag());
		av %= mod, bv %= mod, cv %= mod;
		ret[i] = (av << 30) + (bv << 15) + cv;
		ret[i] %= mod;
		ret[i] += mod;
		ret[i] %= mod;
	}
	return ret;
}

/*
Extends the sample table of a polynomial: given h(0..d), returns h(0..4d+1), all modulo P.

Input h: vector of length d+1 holding h(0), h(1), ..., h(d), where h is a polynomial of degree at most d
Input P: prime modulus; inverses are computed as a^(P-2), and 1..4d+1 must be invertible modulo P
Output: vector of length 4d+2 holding h(0), h(1), ..., h(4d+1)

Lagrange interpolation on the nodes 0..d gives

h(x) = sum_(i=0 to d)[ h(i) * prod(j=0 to d, j != i){(x-j)/(i-j)} ]
     = {prod(j=0 to d)(x-j)} *
       sum(i=0 to d){ h(i)/(i!(d-i)!(-1)^(d-i)) * 1/(x-i) }

With f(i) = h(i)/(i!(d-i)!(-1)^(d-i)) and g(x) = prod(j=0 to d)(x-j):

h(x) = g(x) * sum(i=0 to d){ f(i) * 1/(x-i) }

For x > d the inner sum is coefficient x of the convolution of f(0), ..., f(d) with the
sequence 0, 1/1, 1/2, ..., 1/(4d+1), which is what multiply() is asked to compute.
*/
vector<i64> lagrange(vector<i64> h, i64 P) {
	const int d = (int)h.size() - 1;
	assert(d >= 0);
	const int total = 4*d + 2;	// number of samples in the result

	auto mulmod = [P](i64 x, i64 y){
		return x*y%P;
	};

	// Binary exponentiation; used with exponent P-2 to obtain a modular inverse.
	auto powmod = [&mulmod](i64 b, i64 e){
		i64 acc = 1;
		for(; e; e >>= 1){
			if(e & 1) acc = mulmod(acc, b);
			b = mulmod(b, b);
		}
		return acc;
	};

	// Factorials and inverse factorials of 0..4d+1.
	vector<i64> fact(total), invfact(total);
	fact[0] = 1;
	for(int k = 1; k < total; ++k)
		fact[k] = mulmod(k, fact[k-1]);
	invfact[total-1] = powmod(fact[total-1], P-2);
	for(int k = total-1; k > 0; --k)
		invfact[k-1] = mulmod(invfact[k], k);

	// f(i) = h(i) / (i! (d-i)!) with sign (-1)^(d-i).
	vector<i64> f(d+1);
	for(int i = 0; i <= d; ++i){
		i64 term = mulmod(invfact[d-i], mulmod(invfact[i], h[i]));
		if((d-i) % 2 == 1) term = P - term;
		if(term == P) term = 0;	// P - 0 must be stored as 0
		f[i] = term;
	}

	// inv[k] = 1/k = (k-1)! / k!; inv[0] stays 0.
	vector<i64> inv(total);
	for(int k = 1; k < total; ++k)
		inv[k] = mulmod(fact[k-1], invfact[k]);

	vector<i64> conv = multiply(f, inv, P);

	// The first d+1 samples are the inputs themselves; the rest start out as 0.
	vector<i64> ret(h.begin(), h.end());
	ret.resize(total);

	// g(d+1) = (d+1) * d * ... * 1, then g(x) = g(x-1) * x / (x-d-1) for larger x.
	i64 gx = 1;
	for(int factor = d+1; factor >= 1; --factor)
		gx = mulmod(gx, factor);
	ret[d+1] = mulmod(gx, conv[d+1]);
	for(int x = d+2; x < total; ++x){
		gx = mulmod(gx, mulmod(x, inv[x-d-1]));
		ret[x] = mulmod(gx, conv[x]);
	}

	return ret;
}
/**
 * Extends the sample table `poly` with lagrange() and multiplies neighbouring
 * samples together.
 *
 * @param poly samples h(0), ..., h(d) of a polynomial
 * @param P    modulus, forwarded to lagrange()
 * @return vector of half the extended length whose entry i is
 *         h(2i) * h(2i+1) modulo P
 */
vector<i64> squarepoly(vector<i64> poly, i64 P) {
	const vector<i64> samples = lagrange(poly, P);
	const size_t pairs = samples.size() / 2;

	vector<i64> merged;
	merged.reserve(pairs);
	for(size_t k = 0; k < pairs; ++k)
		merged.push_back(samples[2*k] * samples[2*k+1] % P);
	return merged;
}

/**
 * Computes N! modulo a prime P.
 *
 * Small N are multiplied out directly. Otherwise a table
 * fact_part[i] = prod_{j=1..d} (d*i + j) mod P, i = 0..d, is doubled with
 * squarepoly() until d*(d+1) >= N; N! is then the product of the first N/d
 * table entries times the remaining factors (N/d)*d+1 .. N.
 *
 * @param N argument of the factorial
 * @param P prime modulus
 * @return N! mod P; 0 whenever N >= P, because P then divides N!
 */
i64 factorial(i64 N,i64 P)
{
	auto mul = [P](i64 a, i64 b) {
		return a*b%P;
	};

	// Once N >= P the product 1*2*...*N contains the factor P itself.
	if(N >= P) return 0;

	// Direct product for small N. Besides being cheaper, this keeps tiny primes
	// away from the doubling step, which needs 1..4d+1 to be invertible modulo P
	// (e.g. factorial(4, 5) would otherwise divide by 5 mod 5 and return 0).
	const i64 kDirectLimit = 64;
	if(N < kDirectLimit) {
		i64 small = 1 % P;
		for(i64 i=2; i<=N; ++i) small = mul(small, i);
		return small;
	}

	// For d = 1 the table is {1, 2}: the blocks (0,1] and (1,2].
	i64 d = 1;
	vector<i64> fact_part = {1%P, 2%P};
	while(N > d*(d+1)) {
		fact_part = squarepoly(fact_part, P);
		d *= 2;
	}

	// Whole blocks of length d first, then the leftover tail one factor at a time.
	i64 ans = 1, bucket = N/d;
	for(i64 i=0; i<bucket; ++i) ans = mul(ans, fact_part[i]);
	for(i64 i=bucket*d+1; i<=N; ++i) ans = mul(ans, i);

	return ans;
}

/**
 * Reads N from standard input and prints N! modulo 1e9+7 followed by a newline.
 */
int main()
{
	ios::sync_with_stdio(false);
	cin.tie(nullptr);

	const i64 P = 1e9 + 7;
	i64 N;
	cin >> N;

	// For N >= P the factorial contains the factor P, so the answer is 0 without any computation.
	const i64 answer = (N >= P) ? 0 : factorial(N, P);
	cout << answer << '\n';

	return 0;
}
0