結果

問題 No.1191 数え上げを愛したい(数列編)
ユーザー sanada_atcodersanada_atcoder
提出日時 2020-07-31 22:13:23
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 78 ms / 2,000 ms
コード長 3,752 bytes
コンパイル時間 1,007 ms
コンパイル使用メモリ 102,192 KB
実行使用メモリ 58,208 KB
最終ジャッジ日時 2024-07-06 18:51:27
合計ジャッジ時間 3,496 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 67 ms
58,020 KB
testcase_01 AC 72 ms
57,984 KB
testcase_02 AC 72 ms
58,104 KB
testcase_03 AC 74 ms
57,952 KB
testcase_04 AC 70 ms
58,112 KB
testcase_05 AC 67 ms
58,112 KB
testcase_06 AC 78 ms
58,020 KB
testcase_07 AC 72 ms
58,056 KB
testcase_08 AC 73 ms
58,016 KB
testcase_09 AC 71 ms
57,984 KB
testcase_10 AC 72 ms
58,112 KB
testcase_11 AC 71 ms
58,116 KB
testcase_12 AC 72 ms
58,112 KB
testcase_13 AC 71 ms
58,144 KB
testcase_14 AC 72 ms
58,024 KB
testcase_15 AC 69 ms
58,112 KB
testcase_16 AC 69 ms
58,076 KB
testcase_17 AC 73 ms
58,112 KB
testcase_18 AC 72 ms
58,080 KB
testcase_19 AC 72 ms
58,112 KB
testcase_20 AC 72 ms
58,016 KB
testcase_21 AC 65 ms
58,208 KB
testcase_22 AC 71 ms
58,112 KB
testcase_23 AC 69 ms
57,904 KB
testcase_24 AC 72 ms
58,112 KB
testcase_25 AC 67 ms
58,064 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<iostream>
#include<algorithm>
#include<string>
#include<vector>
#include<cmath>
#include<map>
#include<random>
#include<iomanip>
#include<queue>
#include<stack>
#include<assert.h>
#include<time.h>
#define int long long
#define double long double
#define rep(i,n) for(int i=0;i<n;i++)
#define REP(i,n) for(int i=1;i<=n;i++)
#define ggr getchar();getchar();return 0;
#define prique priority_queue
constexpr auto mod = 998244353;
#define inf 1e15
#define key 1e9
using namespace std;
typedef pair<int, int>P;
template<class T> inline void chmax(T& a, T b) {
	a = std::max(a, b);
}
template<class T> inline void chmin(T& a, T b) {
	a = std::min(a, b);
}
//combination(Nが小さい時はこれを使う
const int MAX = 2330000;
int fac[MAX], finv[MAX], inv[MAX];
// テーブルを作る前処理
void COMinit() {
	fac[0] = fac[1] = 1;
	finv[0] = finv[1] = 1;
	inv[1] = 1;
	for (int i = 2; i < MAX; i++) {
		fac[i] = fac[i - 1] * i % mod;
		inv[i] = mod - inv[mod % i] * (mod / i) % mod;
		finv[i] = finv[i - 1] * inv[i] % mod;
	}
}
int COMB(int n, int k) {
	if (n < k) return 0;
	if (n < 0 || k < 0) return 0;
	return fac[n] * (finv[k] * finv[n - k] % mod) % mod;
}
bool prime(int n) {
	int cnt = 0;
	for (int i = 1; i <= sqrt(n); i++) {
		if (n % i == 0)cnt++;
	}
	if (cnt != 1)return false;
	else return n != 1;
}
int gcd(int x, int y) {
	if (y == 0)return x;
	return gcd(y, x % y);
}
int lcm(int x, int y) {
	return x / gcd(x, y) * y;
}

//繰り返し二乗法(Nが大きい時の場合のcombination)
int mod_pow(int x, int y, int m) {
	int res = 1;
	while (y) {
		if (y & 1) {
			res = res * x % m;
		}
		x = x * x % m;
		y >>= 1;
	}
	return res;
}
int kai(int x, int y) {
	int res = 1;
	for (int i = x - y + 1; i <= x; i++) {
		res *= (i % mod); res %= mod;
	}
	return res;
}
int comb(int x, int y) {
	if (y > x)return 0;
	return kai(x, y) * mod_pow(kai(y, y), mod - 2, mod) % mod;
}
//UnionFind
class UnionFind {
protected:
	int* par, * rank, * size;
public:
	UnionFind(unsigned int size) {
		par = new int[size];
		rank = new int[size];
		this->size = new int[size];
		rep(i, size) {
			par[i] = i;
			rank[i] = 0;
			this->size[i] = 1;
		}
	}
	int find(int n) {
		if (par[n] == n)return n;
		return par[n] = find(par[n]);
	}
	void unite(int n, int m) {
		n = find(n);
		m = find(m);
		if (n == m)return;
		if (rank[n] < rank[m]) {
			par[n] = m;
			size[m] += size[n];
		}
		else {
			par[m] = n;
			size[n] += size[m];
			if (rank[n] == rank[m])rank[n]++;
		}
	}
	bool same(int n, int m) {
		return find(n) == find(m);
	}
	int getsize(int n) {
		return size[find(n)];
	}
};
int dight(int n) {
	int ans = 1;
	while (n >= 10) {
		n /= 10;
		ans++;
	}
	return ans;
}
int dight_sum(int n) {
	int sum = 0;
	rep(i, 20)sum += (n % (int)pow(10, i + 1)) / (pow(10, i));
	return sum;
}
int dight_min(int n) {
	int ans = 9;
	while (n >= 10) {
		ans = min(ans, n % 10);
		n /= 10;
	}
	ans = min(ans, n);
	return ans;
}
int dight_max(int n) {
	int ans = 0;
	while (n >= 10) {
		ans = max(ans, n % 10);
		n /= 10;
	}
	ans = max(ans, n);
	return ans;
}
long long modinv(long long a, long long m) {
	long long b = m, u = 1, v = 0;
	while (b) {
		long long t = a / b;
		a -= t * b; swap(a, b);
		u -= t * v; swap(u, v);
	}
	u %= m;
	if (u < 0) u += m;
	return u;
}
int n, m, a, b;
signed main() {
	COMinit();
	cin >> n >> m >> a >> b;
	assert(2 <= n <= 100000);
	assert(2 <= m <= 300000);
	assert(1 <= a <= b <= m - 1);
	if (a * (n - 1) > b) {
		cout << 0 << endl;
		ggr
	}
	int memo = 1;
	for (int i = 1; i <= n; i++) {
		memo *= i;
		memo %= mod;
	}
	int ans = 0;
	for (int i = 0; i <= b - a * (n - 1); i++) {
		int p = (m - a * (n - 1) - i) % mod;
		ans += (p * COMB(i + n - 2, i)) % mod;
		ans %= mod;
	}
	cout << ans * memo % mod << endl;
	ggr
}
0