結果

問題 No.2613 Sum of Combination
ユーザー shobonvipshobonvip
提出日時 2024-01-25 17:19:36
言語 C++17(gcc12)
(gcc 12.3.0 + boost 1.87.0)
結果
AC  
実行時間 264 ms / 4,500 ms
コード長 2,288 bytes
コンパイル時間 4,060 ms
コンパイル使用メモリ 251,284 KB
実行使用メモリ 17,640 KB
最終ジャッジ日時 2024-09-28 07:22:20
合計ジャッジ時間 10,642 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 49
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<bits/stdc++.h>
#include<atcoder/modint>
#include<atcoder/math>
#include<atcoder/convolution>
using namespace std;
typedef long long ll;
typedef atcoder::modint998244353 mint;

vector<int> pfact(int n){
	vector<int> ret;
	for (int i=2; i*i<=n; i++){
		if (n % i == 0){
			ret.push_back(i);
			while (n % i == 0){
				n /= i;
			}
		}
	}
	if (n > 1) ret.push_back(n);
	return ret;
}

ll modpow(ll n, ll m, ll p){
	ll ret = 1;
	ll tmp = n;
	while(m > 0){
		if (m & 1){
			ret *= tmp;
			ret %= p;
		}
		tmp *= tmp;
		tmp %= p;
		m >>= 1;
	}
	return ret;
}

bool is_p_root(int x, int p, vector<int> &v){
	for (int m: v){
		if (modpow(x, (p-1)/m, p) == 1) return false;
	}
	return true;
}

int findpr(int p){
	random_device seed_gen;
	mt19937 engine(seed_gen());
	uniform_int_distribution<int> dist(1, p-1);
	vector<int> v = pfact(p-1);
	int x;
	do{
		x = dist(engine);
	}while(!is_p_root(x, p, v));
	return x;
}

int main(){
	ll n; cin >> n;
	int p; cin >> p;

	ll mx = p-1;
	vector<ll> fact(mx + 1, 1);
	vector<ll> factinv(mx + 1, 1);
	fact[0] = 1;
	for (int i=1; i<=mx; i++){
		fact[i] = fact[i-1] * i % p;
	}
	factinv[mx] = modpow(fact[mx], p-2, p);
	for (int i=mx; i>=1; i--){
		factinv[i-1] = factinv[i] * i % p;
	}

	ll g = findpr(p);
	vector<ll> taio(p, -1);
	vector<ll> fuku(p-1, 0);

	{
		ll tmp = 1;
		for (int i=0; i<p-1; i++){
			taio[tmp] = i;
			fuku[i] = tmp;
			tmp *= g;
			tmp %= p;
		}
	}


	vector<ll> a;
	while(n > 0){
		a.push_back(n % p);
		n /= p;
	}
	reverse(a.begin(), a.end());

	auto cmb = [&](ll n, ll r) -> ll {
		if (n < r) return 0;
		return fact[n] * factinv[r] % p * factinv[n-r] % p;
	};

	vector<mint> dp(p-1);
	for (int num=0; num<(int)a.size(); num++){
		vector<mint> ndp(p-1);
		vector<mint> g(p-1);
		for (int i=0; i<=a[num]; i++){
			ll k = taio[cmb(a[num], i)];
			assert(k != -1);
			g[k] += 1;
		}
		
		vector<mint> h = atcoder::convolution(g, dp);
		for (int i=0; i<(int)h.size(); i++){
			ndp[i%(p-1)] += h[i];
		}
		
		if (num > 0){
			for (int i=1; i<=a[num]; i++){
				ndp[taio[cmb(a[num], i)]] += 1;
			}
		}

		for (int i=(num<1); i<a[num]; i++){
			ndp[taio[cmb(a[num], i)]%(p-1)] += 1;
		}

		dp = ndp;
	}

	dp[0] += 2;

	mint ans = 0;
	for (int i=0; i<p-1; i++){
		ans += mint(dp[i]) * fuku[i];
	}
	cout << ans.val() << '\n';	
}
0