結果

問題 No.2406 Difference of Coordinate Squared
ユーザー lam6er
提出日時 2025-03-31 17:33:49
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,329 bytes
コンパイル時間 325 ms
コンパイル使用メモリ 82,776 KB
実行使用メモリ 102,704 KB
最終ジャッジ日時 2025-03-31 17:34:18
合計ジャッジ時間 4,342 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other TLE * 1 -- * 54
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def main():
    import sys
    N, M = map(int, sys.stdin.readline().split())
    
    # Precompute factorial, inverse factorial modulo MOD up to N
    max_n = N
    fact = [1] * (max_n + 1)
    for i in range(1, max_n + 1):
        fact[i] = fact[i-1] * i % MOD
    inv_fact = [1] * (max_n + 1)
    inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD)
    for i in range(max_n - 1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
    
    def comb(n, k):
        if n < 0 or k < 0 or k > n:
            return 0
        return fact[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD
    
    inv_4 = pow(4, MOD-2, MOD)
    inv_4_pows = [1] * (N + 1)
    for i in range(1, N + 1):
        inv_4_pows[i] = inv_4_pows[i-1] * inv_4 % MOD
    
    def get_divisors(m):
        divisors = set()
        for d in range(1, int(int(abs(m))**0.5) + 1):
            if m % d == 0:
                divisors.add(d)
                divisors.add(m // d)
                divisors.add(-d)
                divisors.add(-m // d)
        if m == 0:
            return []  # M=0, handled separately
        return list(divisors)
    
    divisors = get_divisors(M)
    if M == 0:
        divisors = [0]
    
    total = 0
    
    for u in divisors:
        if M == 0:
            for v in [0]:
                pass
        else:
            if u == 0:
                continue
            v = M // u
        
        if (u + v) % 2 != 0:
            continue
        
        X = (u + v) // 2
        Y = (u - v) // 2
        
        if (N + u) % 2 != 0:
            continue
        s = (N + u) // 2
        
        if s < 0:
            continue
        
        a_min = max((X + 1) // 2, 0)
        a_max = min(s, (N + X) // 2)
        
        if a_min > a_max:
            continue
        
        for a in range(a_min, a_max + 1):
            k = 2 * a - X
            if k < 0 or k > N:
                continue
            m_val = N - k
            b = s - a
            if b < 0 or b > m_val:
                continue
            
            term = comb(N, k)
            term = term * comb(k, a) % MOD
            term = term * comb(m_val, b) % MOD
            total = (total + term) % MOD
    
    inv_4_N = inv_4_pows[N]
    total = total * inv_4_N % MOD
    
    print(total)

if __name__ == "__main__":
    main()
0