結果
| 問題 | No.2237 Xor Sum Hoge |
| コンテスト | |
| ユーザー | gew1fw |
| 提出日時 | 2025-06-12 16:48:16 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 2,777 bytes |
| コンパイル時間 | 173 ms |
| コンパイル使用メモリ | 82,308 KB |
| 実行使用メモリ | 270,880 KB |
| 最終ジャッジ日時 | 2025-06-12 16:49:40 |
| 合計ジャッジ時間 | 22,913 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | -- * 3 |
| other | TLE * 1 -- * 31 |
ソースコード
MOD = 998244353


def solve(N, B, C):
    """Count length-N sequences of non-negative integers with sum B and XOR C.

    Each bit position k is handled independently: choosing which s_k of the
    N elements have bit k set can be done in binom(N, s_k) ways.  A choice is
    valid when s_k has the parity of bit k of C (the XOR condition) and the
    column counts plus the carries between columns reproduce B (the sum
    condition):  s_k + carry_in == 2 * carry_out + bit_k(B).

    Returns the count modulo MOD.  Assumes 0 <= N < MOD so binomials can be
    built from factorials with Fermat inverses.

    NOTE(review): the carry transition is still quadratic in N per bit.  The
    original submission was judged TLE; this version only removes wasted
    work, so the largest N may still exceed the time limit.  A convolution
    (NTT) based transition would be needed there -- TODO confirm against the
    problem's constraints.
    """
    # For non-negative integers sum >= xor and sum == xor (mod 2), so these
    # inputs can never be realised.
    if B < C or (B - C) & 1:
        return 0

    # binom[s] = C(N, s) mod MOD via factorials and inverse factorials.
    fact = [1] * (N + 1)
    for i in range(1, N + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (N + 1)
    inv_fact[N] = pow(fact[N], MOD - 2, MOD)
    for i in range(N, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD
    binom = [fact[N] * inv_fact[s] % MOD * inv_fact[N - s] % MOD
             for s in range(N + 1)]

    # dp[carry] = ways to fill the bits processed so far, leaving `carry` to
    # be added into the next bit.  Carries stay <= N because
    # carry_out <= (N + carry_in) / 2.
    dp = [0] * (N + 1)
    dp[0] = 1
    # C <= B, so C has no set bits above B's highest bit.
    for k in range(B.bit_length()):
        b = (B >> k) & 1
        c = (C >> k) & 1
        # Bits 0..k already contribute (B mod 2^(k+1)) + carry * 2^(k+1),
        # which must not exceed B; larger carries can never be absorbed.
        cap = min(N, B >> (k + 1))
        nxt = [0] * (N + 1)
        for ci in range(min(N, B >> k) + 1):
            w = dp[ci]
            # s_k = 2*co + b - ci must have the parity of bit k of C.
            if not w or (ci + b - c) & 1:
                continue
            d = b - ci
            lo = max(0, (1 - d) // 2)    # smallest co with s_k >= 0
            hi = min(cap, (N - d) // 2)  # largest co with s_k <= N
            for co in range(lo, hi + 1):
                nxt[co] = (nxt[co] + w * binom[2 * co + d]) % MOD
        dp = nxt
    # At the top bit cap == B >> bit_length == 0, so every surviving path
    # ends with carry 0; for B == 0 the loop is empty and dp[0] == 1.
    return dp[0]


def main():
    """Read "N B C" from standard input and print the count modulo MOD."""
    import sys

    N, B, C = map(int, sys.stdin.read().split()[:3])
    print(solve(N, B, C))


if __name__ == '__main__':
    main()
gew1fw