結果
問題 |
No.2406 Difference of Coordinate Squared
|
ユーザー |
![]() |
提出日時 | 2025-06-12 14:02:46 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 3,140 bytes |
コンパイル時間 | 317 ms |
コンパイル使用メモリ | 82,032 KB |
実行使用メモリ | 93,568 KB |
最終ジャッジ日時 | 2025-06-12 14:03:33 |
合計ジャッジ時間 | 7,126 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 47 WA * 8 |
ソースコード
import sys import math MOD = 998244353 def main(): N, M = map(int, sys.stdin.readline().split()) if M == 0: # Special case, X^2 = Y^2 pass # Handle separately if needed, but current code can process it # Precompute factorials and inverse factorials modulo MOD max_n = N fact = [1] * (max_n + 1) for i in range(1, max_n + 1): fact[i] = fact[i-1] * i % MOD inv_fact = [1] * (max_n + 1) inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD) for i in range(max_n-1, -1, -1): inv_fact[i] = inv_fact[i+1] * (i+1) % MOD def comb(n, k): if k < 0 or k > n: return 0 return fact[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD # Function to compute all (a, b) pairs such that a * b = M def get_divisors(M): if M == 0: return set() divisors = set() # Find all positive divisors of |M| m_abs = abs(M) for i in range(1, int(math.isqrt(m_abs)) + 1): if m_abs % i == 0: divisors.add(i) divisors.add(m_abs // i) # Generate all possible a = d * s where s is ±1 a_list = [] for d in divisors: a_list.append(d) a_list.append(-d) pairs = set() for a in a_list: if M % a != 0: continue b = M // a pairs.add((a, b)) return pairs pairs = get_divisors(M) sum_total = 0 inv4_powN = pow(4, N, MOD) inv4_powN = pow(inv4_powN, MOD-2, MOD) # since (1/4)^N = 4^(-N) mod MOD for a, b in pairs: if a % 2 != b % 2: continue s_x = (a + b) // 2 s_y = (b - a) // 2 # Check congruence conditions for k if (s_x % 2) != ((N - s_y) % 2): continue t = s_x % 2 # Determine the range of k L = max(abs(s_x), 0) U = N - abs(s_y) if U < L: continue # Find first_k and last_k if L % 2 == t: first_k = L else: first_k = L + 1 if first_k > U: continue if U % 2 == t: last_k = U else: last_k = U - 1 if last_k < first_k: continue # Iterate over k from first_k to last_k, step 2 step = 2 for k in range(first_k, last_k + 1, step): m = N - k # Check x_plus and y_plus x_plus = (k + s_x) // 2 if x_plus < 0 or x_plus > k: continue y_plus = (m + s_y) // 2 if y_plus < 0 or y_plus > m: continue # Compute combinations c_n_k = comb(N, k) c_k_x = comb(k, x_plus) c_m_y = comb(m, y_plus) term = c_n_k * c_k_x % MOD term = term * c_m_y % MOD sum_total = (sum_total + term) % MOD # Multiply by inv4^N result = sum_total * inv4_powN % MOD print(result) if __name__ == "__main__": main()