結果
問題 | No.2237 Xor Sum Hoge |
ユーザー |
![]() |
提出日時 | 2025-06-12 21:38:11 |
言語 | PyPy3 (7.3.15) |
結果 | TLE |
実行時間 | - |
コード長 | 2,777 bytes |
コンパイル時間 | 164 ms |
コンパイル使用メモリ | 82,280 KB |
実行使用メモリ | 270,768 KB |
最終ジャッジ日時 | 2025-06-12 21:42:48 |
合計ジャッジ時間 | 22,750 ms |
ジャッジサーバーID (参考情報) | judge5 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | -- * 3 |
other | TLE * 1 -- * 31 |
ソースコード
MOD = 998244353
_PRIMITIVE_ROOT = 3  # 3 generates the multiplicative group mod MOD (MOD - 1 = 119 * 2**23)
_NAIVE_LIMIT = 32  # factors this short are multiplied schoolbook-style (cheaper than an NTT)


def _binomial_row(n):
    """Return ``[C(n, 0), C(n, 1), ..., C(n, n)]`` reduced modulo MOD.

    Assumes ``0 <= n < MOD`` so every factorial is invertible.
    """
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], MOD - 2, MOD)
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD
    top = fact[n]
    return [top * inv_fact[k] % MOD * inv_fact[n - k] % MOD for k in range(n + 1)]


def _ntt(a, inverse=False):
    """Apply the number-theoretic transform modulo MOD to ``a`` in place.

    ``len(a)`` must be a power of two no larger than 2**23.  With
    ``inverse=True`` the inverse transform, including the ``1/len(a)``
    scaling, is applied instead.
    """
    mod = MOD
    n = len(a)
    # Bit-reversal permutation, so the iterative butterflies can work in place.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:
        half = length >> 1
        # Primitive length-th root of unity (or its inverse for the inverse transform).
        root = pow(_PRIMITIVE_ROOT, (mod - 1) // length, mod)
        if inverse:
            root = pow(root, mod - 2, mod)
        twiddles = [1] * half
        for k in range(1, half):
            twiddles[k] = twiddles[k - 1] * root % mod
        for start in range(0, n, length):
            for k in range(half):
                lo = start + k
                hi = lo + half
                u = a[lo]
                v = a[hi] * twiddles[k] % mod
                a[lo] = (u + v) % mod
                a[hi] = (u - v) % mod
        length <<= 1
    if inverse:
        scale = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * scale % mod


def _convolve(a, b):
    """Return the coefficients of the polynomial product ``a * b`` modulo MOD.

    Coefficient lists are lowest degree first, with entries already reduced
    modulo MOD.  An empty list is the zero polynomial, so the product with it
    is ``[]``.  Neither input is modified.
    """
    if not a or not b:
        return []
    out_len = len(a) + len(b) - 1
    if min(len(a), len(b)) <= _NAIVE_LIMIT:
        out = [0] * out_len
        for i, x in enumerate(a):
            if x:
                for j, y in enumerate(b):
                    out[i + j] = (out[i + j] + x * y) % MOD
        return out
    size = 1 << (out_len - 1).bit_length()  # smallest power of two >= out_len
    fa = a + [0] * (size - len(a))
    fb = b + [0] * (size - len(b))
    _ntt(fa)
    _ntt(fb)
    for i in range(size):
        fa[i] = fa[i] * fb[i] % MOD
    _ntt(fa, inverse=True)
    del fa[out_len:]
    return fa


def _count_sequences(n, b, c):
    """Count length-``n`` sequences of non-negative integers with sum ``b`` and xor ``c``.

    Returns the count modulo MOD.

    Let ``s_k`` be how many elements have bit ``k`` set.  Bits are chosen
    independently, in ``C(n, s_k)`` ways each.  The xor forces
    ``s_k == c_k (mod 2)``, so write ``s_k = c_k + 2*t_k``.  The sum condition
    ``sum(s_k * 2**k) == b`` then becomes ``sum(t_k * 2**k) == d`` with
    ``d = (b - c) // 2``, which needs ``b >= c`` and ``b - c`` even.

    A carry DP over the bits of ``d`` solves that.  Each bit multiplies the
    carry polynomial by ``sum_t C(n, c_k + 2t) x**t`` and keeps the
    coefficients whose parity matches ``d_k``.  The carry never exceeds
    ``n // 2``, so each step is one O(n)-length convolution (NTT-based).
    """
    if b < c or (b - c) & 1:
        return 0
    d = (b - c) >> 1
    binom = _binomial_row(n)
    # factors[p][t] == C(n, p + 2*t): ways to choose which elements carry bit k
    # when bit k of c equals p.
    factors = (binom[0::2], binom[1::2])
    carry = [1]  # carry[j]: weighted count of partial choices with carry j into bit k
    k = 0
    while d >> k:
        product = _convolve(carry, factors[(c >> k) & 1])
        # Keep totals whose low bit equals bit k of d; the quotient is the next carry.
        carry = product[(d >> k) & 1::2]
        if not carry:
            return 0
        k += 1
    # Above the top bit of d every t_k must be 0, so those bits only contribute
    # C(n, c_k): a factor of n for each remaining set bit of c.
    remaining_ones = bin(c >> k).count("1")
    return carry[0] * pow(n, remaining_ones, MOD) % MOD


def main():
    """Read ``N B C`` from standard input and print the count modulo MOD."""
    import sys

    n, b, c = map(int, sys.stdin.read().split()[:3])
    print(_count_sequences(n, b, c))


if __name__ == '__main__':
    main()