結果

問題 No.3394 Big Binom
コンテスト
ユーザー V_Melville
提出日時 2025-12-01 20:29:21
言語 C++23
(gcc 13.3.0 + boost 1.89.0)
結果
TLE  
実行時間 -
コード長 3,611 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 633 ms
コンパイル使用メモリ 36,520 KB
実行使用メモリ 16,208 KB
最終ジャッジ日時 2025-12-01 20:29:26
合計ジャッジ時間 4,060 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3 TLE * 1
other -- * 21
権限があれば一括ダウンロードができます
コンパイルメッセージ
main.cpp: In function ‘int main()’:
main.cpp:119:14: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  119 |         scanf("%d%d", &n, &k);
      |         ~~~~~^~~~~~~~~~~~~~~~

ソースコード

diff #
raw source code

#include <cstdint>
#include <cstdio>
#include <ctime>

using u32 = unsigned int;
using u64 = unsigned long long;
using i64 = long long;

static u32 mod, r, n2_;

i64 exgcd(i64 a, i64 b, i64 &x, i64 &y) {
	i64 d = a;
	if (b == 0) x = 1, y = 0;
	else d = exgcd(b, a % b, y, x), y -= a / b * x;
	return d;
}
inline u32 mul(u32 x, u32 y) {
	unsigned ret = (1ull * x * y + 1ull * (u32(1ull * x * y) * r) * mod) >> 32;
	return ret < mod ? ret : ret - mod;
}
inline u32 add(u32 a, u32 b) { u32 res = a + b; return res >= mod ? res - mod : res; } 

struct u32x8 {
	u32 v[8];

	u32x8() = default;
	u32x8(u32 val) { for (int i = 0; i < 8; i++) v[i] = val; }
	u32x8 operator+(const u32x8& other) const { u32x8 result; for (int i = 0; i < 8; i++) result.v[i] = add(v[i], other.v[i]); return result; }
	u32x8 operator*(const u32x8& other) const { u32x8 result; for (int i = 0; i < 8; i++) result.v[i] = mul(v[i], other.v[i]); return result; }
	u32x8& operator+=(const u32x8& other) { for (int i = 0; i < 8; i++) v[i] = add(v[i], other.v[i]); return *this; }
	u32x8& operator*=(const u32x8& other) { for (int i = 0; i < 8; i++) v[i] = mul(v[i], other.v[i]); return *this; }
};

inline u32 mon_in(u32 x) { return mul(x, n2_); }

inline u32 mon_out(u32 x) { u32 ret = ((x + (u64)((u32)x * r) * mod) >> 32); return ret < mod ? ret : ret - mod; }

inline u32 qpow(u32 n, u32 m, u32 p) {
	if (!m) return 1;
	u32 ret = qpow(n, m >> 1, p);
	ret = (u64)ret * ret % p;
	if (m & 1) return (u64)ret * n % p;
	else return ret;
}

u32 solve(int N, u32 p) {
	const int mv = 8;
	int n = N - (N & (256 * (1 << mv) - 1));
	
	mod = p;
	n2_ = -(u64)mod % mod;
	
	i64 x, y;
	exgcd(mod, 1ll << 32, x, y);
	r = -u32(x);
	
	u32 as = mon_in(1);
	u32 as2 = mon_in(1);
	
	u32x8 ans[8];
	u32x8 ml[8];
	u32x8 ad_val(mon_in(64));

	for (int i = 0; i < 8; i++) ans[i] = u32x8(mon_in(1)); 
	for (int i = 0; i < 8; i++) 
		for (int j = 0; j < 8; j++)
			ml[i].v[j] = mon_in(i * 8 + j + 1);
	
	for (unsigned i = 1; i + 63 <= (n >> mv); i += 64) 
		for (int j = 0; j < 8; j++) {
			ans[j] *= ml[j];
			ml[j] += ad_val;
		}

	for (int i = 0; i < 8; i++) { 
		u32 odd_prod = mon_in(1);
		u32 even_prod = mon_in(1);

		for (int j = 0; j < 8; j++)
			if (j & 1) odd_prod = mul(odd_prod, ans[i].v[j]);
			else even_prod = mul(even_prod, ans[i].v[j]);

		as = mul(as, odd_prod);
		as2 = mul(as2, even_prod);
	}

	for (int j = mv - 1; j >= 0; j--) {
		as = mul(as, as2);
		as = mul(as, mon_in(qpow(2, n >> (j + 1), p)));
		
		u32x8 inner_ans[8];
		u32x8 inner_ml[8];
		u32x8 inner_ad_val(mon_in(128));
		
		const unsigned add_ = n >> (j + 1);
		for (int i = 0; i < 8; i++) inner_ans[i] = u32x8(mon_in(1)); 
		for (int i = 0; i < 8; i++) 
			for (int k = 0; k < 8; k++)
				inner_ml[i].v[k] = mon_in(add_ + i * 16 + k * 2 + 1);
		for (unsigned i = add_; i + 127 <= (n >> j); i += 128) 
			for (int k = 0; k < 8; k++) {
				inner_ans[k] *= inner_ml[k];
				inner_ml[k] += inner_ad_val;
			}
		for (int i = 0; i < 8; i++) { 
			u32 prod0 = mul(mul(inner_ans[i].v[0], inner_ans[i].v[1]), mul(inner_ans[i].v[2], inner_ans[i].v[3]));
			u32 prod1 = mul(mul(inner_ans[i].v[4], inner_ans[i].v[5]), mul(inner_ans[i].v[6], inner_ans[i].v[7]));
			as2 = mul(as2, mul(prod0, prod1));
		}
	}
	as = mul(as, as2);
	as = mon_out(as);
	for (int i = n + 1; i <= N; i++) as = (u64)as * i % p; 
	return as;
}

int main() {
	int n, k;
	scanf("%d%d", &n, &k);
	int p = 998244353;
	n %= p;
	
	u32 a = solve(n, p);
	u32 b = solve(n-k, p);
	u32 c = solve(k, p);
	u32 invb = qpow(b, p-2, p);
	u32 invc = qpow(c, p-2, p);
	i64 ans = (i64)a*invb%p*invc%p;
	printf("%lld\n", ans);
	
	return 0;
}
0