結果

問題 No.2406 Difference of Coordinate Squared
ユーザー gew1fw
提出日時 2025-06-12 14:02:46
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,140 bytes
コンパイル時間 317 ms
コンパイル使用メモリ 82,032 KB
実行使用メモリ 93,568 KB
最終ジャッジ日時 2025-06-12 14:03:33
合計ジャッジ時間 7,126 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 47 WA * 8
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math

MOD = 998244353

def main():
    N, M = map(int, sys.stdin.readline().split())

    if M == 0:
        # Special case, X^2 = Y^2
        pass  # Handle separately if needed, but current code can process it
    
    # Precompute factorials and inverse factorials modulo MOD
    max_n = N
    fact = [1] * (max_n + 1)
    for i in range(1, max_n + 1):
        fact[i] = fact[i-1] * i % MOD
    inv_fact = [1] * (max_n + 1)
    inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD)
    for i in range(max_n-1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
    
    def comb(n, k):
        if k < 0 or k > n:
            return 0
        return fact[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD
    
    # Function to compute all (a, b) pairs such that a * b = M
    def get_divisors(M):
        if M == 0:
            return set()
        divisors = set()
        # Find all positive divisors of |M|
        m_abs = abs(M)
        for i in range(1, int(math.isqrt(m_abs)) + 1):
            if m_abs % i == 0:
                divisors.add(i)
                divisors.add(m_abs // i)
        # Generate all possible a = d * s where s is ±1
        a_list = []
        for d in divisors:
            a_list.append(d)
            a_list.append(-d)
        pairs = set()
        for a in a_list:
            if M % a != 0:
                continue
            b = M // a
            pairs.add((a, b))
        return pairs
    
    pairs = get_divisors(M)
    
    sum_total = 0
    inv4_powN = pow(4, N, MOD)
    inv4_powN = pow(inv4_powN, MOD-2, MOD)  # since (1/4)^N = 4^(-N) mod MOD
    
    for a, b in pairs:
        if a % 2 != b % 2:
            continue
        s_x = (a + b) // 2
        s_y = (b - a) // 2
        
        # Check congruence conditions for k
        if (s_x % 2) != ((N - s_y) % 2):
            continue
        t = s_x % 2
        
        # Determine the range of k
        L = max(abs(s_x), 0)
        U = N - abs(s_y)
        if U < L:
            continue
        
        # Find first_k and last_k
        if L % 2 == t:
            first_k = L
        else:
            first_k = L + 1
            if first_k > U:
                continue
        
        if U % 2 == t:
            last_k = U
        else:
            last_k = U - 1
            if last_k < first_k:
                continue
        
        # Iterate over k from first_k to last_k, step 2
        step = 2
        for k in range(first_k, last_k + 1, step):
            m = N - k
            # Check x_plus and y_plus
            x_plus = (k + s_x) // 2
            if x_plus < 0 or x_plus > k:
                continue
            y_plus = (m + s_y) // 2
            if y_plus < 0 or y_plus > m:
                continue
            # Compute combinations
            c_n_k = comb(N, k)
            c_k_x = comb(k, x_plus)
            c_m_y = comb(m, y_plus)
            term = c_n_k * c_k_x % MOD
            term = term * c_m_y % MOD
            sum_total = (sum_total + term) % MOD
    
    # Multiply by inv4^N
    result = sum_total * inv4_powN % MOD
    print(result)

if __name__ == "__main__":
    main()
0