結果

問題 No.2237 Xor Sum Hoge
ユーザー gew1fw
提出日時 2025-06-12 12:58:00
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,030 bytes
コンパイル時間 185 ms
コンパイル使用メモリ 82,684 KB
実行使用メモリ 279,680 KB
最終ジャッジ日時 2025-06-12 13:05:08
合計ジャッジ時間 22,749 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 3
other TLE * 1 -- * 31
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

mod = 998244353

def main():
    n, B, C = map(int, sys.stdin.readline().split())

    # Precompute the bits of B and C
    b = [(B >> k) & 1 for k in range(60)]
    c = [(C >> k) & 1 for k in range(60)]

    # Precompute factorial and inverse factorial modulo mod
    max_n = n
    fact = [1] * (max_n + 1)
    for i in range(1, max_n + 1):
        fact[i] = fact[i-1] * i % mod

    inv_fact = [1] * (max_n + 1)
    inv_fact[max_n] = pow(fact[max_n], mod-2, mod)
    for i in range(max_n - 1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % mod

    # Precompute combination values C(n, t)
    comb_pre = [0] * (n + 1)
    for t in range(n + 1):
        if t > n:
            comb_pre[t] = 0
        else:
            comb_pre[t] = fact[n] * inv_fact[t] % mod * inv_fact[n - t] % mod

    dp = defaultdict(int)
    dp[0] = 1

    for k in range(60):
        new_dp = defaultdict(int)
        bk = b[k]
        ck = c[k]
        for carry, cnt in dp.items():
            # Check if (carry + ck) mod 2 == bk mod 2
            if (carry + ck) % 2 != bk % 2:
                continue
            # Determine the parity required for t
            r = ck
            # Calculate m which is part of the formula for c_next
            m = (r + carry - bk) // 2
            # Determine the valid range for c_next
            min_cnext = (carry - bk + 1) // 2
            max_cnext = (n + carry - bk) // 2

            if min_cnext > max_cnext:
                continue

            # Iterate over possible c_next values
            for cnext in range(min_cnext, max_cnext + 1):
                t = 2 * cnext + bk - carry
                if t < 0 or t > n:
                    continue
                if t % 2 != r:
                    continue
                comb_val = comb_pre[t]
                new_dp[cnext] = (new_dp[cnext] + cnt * comb_val) % mod

        dp = new_dp
        if not dp:
            break

    print(dp.get(0, 0) % mod)

if __name__ == "__main__":
    main()
0