結果
問題 |
No.2406 Difference of Coordinate Squared
|
ユーザー |
![]() |
提出日時 | 2025-04-16 16:47:45 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 3,228 bytes |
コンパイル時間 | 543 ms |
コンパイル使用メモリ | 81,440 KB |
実行使用メモリ | 92,696 KB |
最終ジャッジ日時 | 2025-04-16 16:50:37 |
合計ジャッジ時間 | 7,379 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 47 WA * 8 |
ソースコード
MOD = 998244353 def main(): import sys N, M = map(int, sys.stdin.readline().split()) max_n = N # Precompute factorial and inverse factorial modulo MOD factorial = [1] * (max_n + 1) for i in range(1, max_n + 1): factorial[i] = factorial[i-1] * i % MOD inv_fact = [1] * (max_n + 1) inv_fact[max_n] = pow(factorial[max_n], MOD-2, MOD) for i in range(max_n-1, -1, -1): inv_fact[i] = inv_fact[i+1] * (i+1) % MOD def comb(n, k): if k < 0 or k > n: return 0 return factorial[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD if M == 0: print(0) return # Generate all factor pairs (A, B) of M with same parity factors = [] m_abs = abs(M) if m_abs == 0: pass else: for d in range(1, int(m_abs**0.5) + 1): if m_abs % d == 0: A1 = d B1 = M // d if (A1 + B1) % 2 == 0: factors.append((A1, B1)) A2 = -d B2 = -M // d if (A2 + B2) % 2 == 0: factors.append((A2, B2)) if d != m_abs // d: A1 = m_abs // d B1 = M // A1 if (A1 + B1) % 2 == 0: factors.append((A1, B1)) A2 = -A1 B2 = -B1 if (A2 + B2) % 2 == 0: factors.append((A2, B2)) seen = set() unique_factors = [] for A, B in factors: if (A, B) not in seen: seen.add((A, B)) unique_factors.append((A, B)) factors = unique_factors total = 0 for A, B in factors: X = (A + B) // 2 Y = (B - A) // 2 if (A + B) % 2 != 0 or (B - A) % 2 != 0: continue if (X + Y) % 2 != N % 2: continue a_min = max(abs(X), 0) a_max = N - abs(Y) if a_min > a_max: continue a_parity = X % 2 start = a_min if a_min % 2 == a_parity else a_min + 1 end = a_max if a_max % 2 == a_parity else a_max - 1 if start > end: continue num_a = ((end - start) // 2) + 1 first_a = start last_a = end step = 2 a = first_a while a <= last_a: b = N - a if b < 0: a += step continue if abs(Y) > b: a += step continue if (Y + b) % 2 != 0: a += step continue l = (Y + b) // 2 if l < 0 or l > b: a += step continue k = (X + a) // 2 if k < 0 or k > a: a += step continue c = comb(N, a) c_a = comb(a, k) c_b = comb(b, l) total = (total + c * c_a % MOD * c_b) % MOD a += step inv_4 = pow(4, MOD-2, MOD) inv_4n = pow(inv_4, N, MOD) ans = total * inv_4n % MOD print(ans) if __name__ == '__main__': main()