結果

問題 No.2406 Difference of Coordinate Squared
ユーザー gew1fw
提出日時 2025-06-12 20:49:58
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,903 bytes
コンパイル時間 231 ms
コンパイル使用メモリ 82,304 KB
実行使用メモリ 93,056 KB
最終ジャッジ日時 2025-06-12 20:53:11
合計ジャッジ時間 6,639 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 47 WA * 8
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def main():
    import sys
    N, M = map(int, sys.stdin.readline().split())
    
    max_n = N
    fact = [1] * (max_n + 1)
    for i in range(1, max_n + 1):
        fact[i] = fact[i-1] * i % MOD
    inv_fact = [1] * (max_n + 1)
    inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD)
    for i in range(max_n-1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
    
    inv_2_2N = pow(4, N, MOD)
    inv_2_2N = pow(inv_2_2N, MOD-2, MOD)
    
    def comb(n, k):
        if k < 0 or k > n:
            return 0
        return fact[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD
    
    def get_factors(m):
        if m == 0:
            return []
        m_abs = abs(m)
        factors = {}
        i = 2
        while i * i <= m_abs:
            while m_abs % i == 0:
                factors[i] = factors.get(i, 0) + 1
                m_abs //= i
            i += 1
        if m_abs > 1:
            factors[m_abs] = factors.get(m_abs, 0) + 1
        
        divisors = [1]
        for p, exp in factors.items():
            temp = []
            for d in divisors:
                current = d
                for e in range(1, exp + 1):
                    current *= p
                    temp.append(current)
            divisors += temp
        
        divisors = list(set(divisors))
        all_divisors = []
        for d in divisors:
            all_divisors.append(d)
            all_divisors.append(-d)
        all_divisors = list(set(all_divisors))
        
        factor_pairs = []
        for a in all_divisors:
            if a == 0:
                continue
            if m % a != 0:
                continue
            b = m // a
            if (a % 2) == (b % 2):
                factor_pairs.append((a, b))
        return factor_pairs
    
    factor_pairs = get_factors(M)
    
    total = 0
    seen = set()
    for a, b in factor_pairs:
        if (a, b) in seen:
            continue
        seen.add((a, b))
        x = (a + b) // 2
        z = (b - a) // 2
        
        if (x + (N - z)) % 2 != 0:
            continue
        
        k_min = max(abs(x), 0)
        k_max = N - abs(z)
        if k_min > k_max:
            continue
        
        parity = x % 2
        start = k_min if (k_min % 2 == parity) else k_min + 1
        if start > k_max:
            continue
        
        step = 2
        num_terms = ((k_max - start) // step) + 1
        
        for k in range(start, k_max + 1, step):
            l = N - k
            if l < abs(z):
                continue
            a_x = (x + k) // 2
            b_z = (z + l) // 2
            c_nk = comb(N, k)
            c_ka = comb(k, a_x)
            c_lb = comb(l, b_z)
            term = c_nk * c_ka % MOD
            term = term * c_lb % MOD
            total = (total + term) % MOD
    
    total = total * inv_2_2N % MOD
    print(total)

if __name__ == '__main__':
    main()
0