結果

問題 No.2237 Xor Sum Hoge
ユーザー gew1fw
提出日時 2025-06-12 21:38:11
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,777 bytes
コンパイル時間 164 ms
コンパイル使用メモリ 82,280 KB
実行使用メモリ 270,768 KB
最終ジャッジ日時 2025-06-12 21:42:48
合計ジャッジ時間 22,750 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 3
other TLE * 1 -- * 31
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    N = int(data[0])
    B = int(data[1])
    C = int(data[2])
    
    if B < C:
        print(0)
        return
    
    max_n = N
    fact = [1] * (max_n + 1)
    for i in range(1, max_n + 1):
        fact[i] = fact[i-1] * i % MOD
    inv_fact = [1] * (max_n + 1)
    inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD)
    for i in range(max_n -1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
    
    def comb(n, k):
        if k < 0 or k > n:
            return 0
        return fact[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD
    
    bin_B = [0] * 61
    bin_C = [0] * 61
    for k in range(60, -1, -1):
        bin_B[k] = (B >> k) & 1
        bin_C[k] = (C >> k) & 1
    
    C_n = [0] * (N+1)
    for k in range(N+1):
        C_n[k] = comb(N, k)
    
    prefix_sums = [0] * (N+2)
    for s in range(N+1):
        prefix_sums[s+1] = (prefix_sums[s] + C_n[s]) % MOD
    
    prefix_alt_sums = [0] * (N+2)
    for s in range(N+1):
        sign = 1 if s % 2 == 0 else -1
        term = C_n[s] * sign
        prefix_alt_sums[s+1] = (prefix_alt_sums[s] + term) % MOD
    
    inv_2 = (MOD + 1) // 2
    
    from collections import defaultdict
    dp = [defaultdict(int) for _ in range(62)]
    dp[0][0] = 1
    
    for k in range(61):
        current_dp = dp[k]
        b_k = bin_B[k]
        c_k = bin_C[k]
        for carry_in, ways in current_dp.items():
            if (carry_in + (b_k - c_k)) % 2 != 0:
                continue
            a = max(0, b_k - carry_in)
            a = max(a, 0)
            a_val = a
            b_val = N
            if a_val > b_val:
                continue
            sum_total = (prefix_sums[b_val + 1] - prefix_sums[a_val]) % MOD
            sum_alt_total = (prefix_alt_sums[b_val + 1] - prefix_alt_sums[a_val]) % MOD
            if c_k % 2 == 0:
                sum_ways = (sum_total + sum_alt_total) * inv_2 % MOD
            else:
                sum_ways = (sum_total - sum_alt_total) * inv_2 % MOD
            sum_ways = sum_ways % MOD
            min_s = 2 * 0 + (b_k - carry_in)
            if min_s < a_val:
                required_min_carry = (a_val - (b_k - carry_in) + 1) // 2
                min_carry = max(0, required_min_carry)
            else:
                min_carry = 0
            max_carry = (N - (b_k - carry_in)) // 2
            for carry_out in range(min_carry, max_carry + 1):
                s_k = 2 * carry_out + (b_k - carry_in)
                if s_k < a_val or s_k > b_val or (s_k % 2) != c_k:
                    continue
                dp[k+1][carry_out] = (dp[k+1][carry_out] + ways * C_n[s_k]) % MOD
    
    print(dp[61].get(0, 0) % MOD)

if __name__ == '__main__':
    main()
0