結果

問題 No.2406 Difference of Coordinate Squared
ユーザー lam6er
提出日時 2025-04-16 16:47:45
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,228 bytes
コンパイル時間 543 ms
コンパイル使用メモリ 81,440 KB
実行使用メモリ 92,696 KB
最終ジャッジ日時 2025-04-16 16:50:37
合計ジャッジ時間 7,379 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 47 WA * 8
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def main():
    import sys
    N, M = map(int, sys.stdin.readline().split())
    
    max_n = N
    # Precompute factorial and inverse factorial modulo MOD
    factorial = [1] * (max_n + 1)
    for i in range(1, max_n + 1):
        factorial[i] = factorial[i-1] * i % MOD
    inv_fact = [1] * (max_n + 1)
    inv_fact[max_n] = pow(factorial[max_n], MOD-2, MOD)
    for i in range(max_n-1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
    
    def comb(n, k):
        if k < 0 or k > n:
            return 0
        return factorial[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD
    
    if M == 0:
        print(0)
        return
    
    # Generate all factor pairs (A, B) of M with same parity
    factors = []
    m_abs = abs(M)
    if m_abs == 0:
        pass
    else:
        for d in range(1, int(m_abs**0.5) + 1):
            if m_abs % d == 0:
                A1 = d
                B1 = M // d
                if (A1 + B1) % 2 == 0:
                    factors.append((A1, B1))
                A2 = -d
                B2 = -M // d
                if (A2 + B2) % 2 == 0:
                    factors.append((A2, B2))
                if d != m_abs // d:
                    A1 = m_abs // d
                    B1 = M // A1
                    if (A1 + B1) % 2 == 0:
                        factors.append((A1, B1))
                    A2 = -A1
                    B2 = -B1
                    if (A2 + B2) % 2 == 0:
                        factors.append((A2, B2))
    
    seen = set()
    unique_factors = []
    for A, B in factors:
        if (A, B) not in seen:
            seen.add((A, B))
            unique_factors.append((A, B))
    factors = unique_factors
    
    total = 0
    for A, B in factors:
        X = (A + B) // 2
        Y = (B - A) // 2
        
        if (A + B) % 2 != 0 or (B - A) % 2 != 0:
            continue
        
        if (X + Y) % 2 != N % 2:
            continue
        
        a_min = max(abs(X), 0)
        a_max = N - abs(Y)
        if a_min > a_max:
            continue
        
        a_parity = X % 2
        start = a_min if a_min % 2 == a_parity else a_min + 1
        end = a_max if a_max % 2 == a_parity else a_max - 1
        if start > end:
            continue
        
        num_a = ((end - start) // 2) + 1
        
        first_a = start
        last_a = end
        step = 2
        
        a = first_a
        while a <= last_a:
            b = N - a
            if b < 0:
                a += step
                continue
            if abs(Y) > b:
                a += step
                continue
            if (Y + b) % 2 != 0:
                a += step
                continue
            l = (Y + b) // 2
            if l < 0 or l > b:
                a += step
                continue
            k = (X + a) // 2
            if k < 0 or k > a:
                a += step
                continue
            c = comb(N, a)
            c_a = comb(a, k)
            c_b = comb(b, l)
            total = (total + c * c_a % MOD * c_b) % MOD
            a += step
    
    inv_4 = pow(4, MOD-2, MOD)
    inv_4n = pow(inv_4, N, MOD)
    ans = total * inv_4n % MOD
    print(ans)

if __name__ == '__main__':
    main()
0