結果

問題 No.2406 Difference of Coordinate Squared
ユーザー gew1fw
提出日時 2025-06-12 16:03:51
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,932 bytes
コンパイル時間 185 ms
コンパイル使用メモリ 82,432 KB
実行使用メモリ 93,056 KB
最終ジャッジ日時 2025-06-12 16:03:59
合計ジャッジ時間 7,462 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 47 WA * 8
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math
MOD = 998244353

def main():
    N, M = map(int, sys.stdin.readline().split())
    max_n = N
    # Precompute factorial and inverse factorial modulo MOD
    factorial = [1] * (max_n + 1)
    for i in range(1, max_n + 1):
        factorial[i] = factorial[i-1] * i % MOD
    inv_fact = [1] * (max_n + 1)
    inv_fact[max_n] = pow(factorial[max_n], MOD-2, MOD)
    for i in range(max_n-1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD

    def comb(n, k):
        if k < 0 or k > n:
            return 0
        return factorial[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD

    # Enumerate all (s, t) pairs where s*t = M and s and t have the same parity
    factors = []
    if M == 0:
        # All pairs (s, 0) and (0, t), but s and t must have the same parity
        # So s must be even (since 0 is even)
        factors.append((0, 0))
    else:
        absM = abs(M)
        for i in range(1, int(math.isqrt(absM)) + 1):
            if absM % i == 0:
                j = absM // i
                for s_sign in [1, -1]:
                    for t_sign in [1, -1]:
                        s = s_sign * i
                        t = t_sign * j
                        if M == s * t and (s + t) % 2 == 0:
                            factors.append((s, t))
                        if i != j:
                            s = s_sign * j
                            t = t_sign * i
                            if M == s * t and (s + t) % 2 == 0:
                                factors.append((s, t))
        # Deduplicate factors
        factors = list(set(factors))

    inv4 = pow(4, MOD-2, MOD)
    inv4_powN = pow(inv4, N, MOD)
    total = 0

    for s, t in factors:
        k = (s + t) // 2
        l = (t - s) // 2
        # Check if a and b exist
        a_min = max(abs(k), 0)
        b_min = abs(l)
        a_max = N - b_min
        if a_min > a_max:
            continue
        # Check parity
        if (k + (N - l)) % 2 != 0:
            continue
        # a must be >= a_min, <= a_max, and a ≡ k mod 2
        a_start = a_min
        if a_start % 2 != k % 2:
            a_start += 1
        if a_start > a_max:
            continue
        a_end = a_max
        if a_end % 2 != k % 2:
            a_end -= 1
        if a_end < a_start:
            continue
        # Compute the sum over a in a_start, a_start+2, ..., a_end
        sum_contribution = 0
        for a in range(a_start, a_end + 1, 2):
            b = N - a
            if b < abs(l):
                continue
            if (b - l) % 2 != 0:
                continue
            c1 = comb(N, a)
            c2 = comb(a, (a + k) // 2)
            c3 = comb(b, (b + l) // 2)
            sum_contribution = (sum_contribution + c1 * c2 % MOD * c3) % MOD
        total = (total + sum_contribution) % MOD

    # Multiply by inv4^N
    total = total * inv4_powN % MOD
    print(total)

if __name__ == "__main__":
    main()
0