結果
問題 |
No.2406 Difference of Coordinate Squared
|
ユーザー |
![]() |
提出日時 | 2025-06-12 20:50:10 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 2,932 bytes |
コンパイル時間 | 615 ms |
コンパイル使用メモリ | 81,808 KB |
実行使用メモリ | 92,760 KB |
最終ジャッジ日時 | 2025-06-12 20:54:09 |
合計ジャッジ時間 | 7,006 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 47 WA * 8 |
ソースコード
import sys import math MOD = 998244353 def main(): N, M = map(int, sys.stdin.readline().split()) max_n = N # Precompute factorial and inverse factorial modulo MOD factorial = [1] * (max_n + 1) for i in range(1, max_n + 1): factorial[i] = factorial[i-1] * i % MOD inv_fact = [1] * (max_n + 1) inv_fact[max_n] = pow(factorial[max_n], MOD-2, MOD) for i in range(max_n-1, -1, -1): inv_fact[i] = inv_fact[i+1] * (i+1) % MOD def comb(n, k): if k < 0 or k > n: return 0 return factorial[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD # Enumerate all (s, t) pairs where s*t = M and s and t have the same parity factors = [] if M == 0: # All pairs (s, 0) and (0, t), but s and t must have the same parity # So s must be even (since 0 is even) factors.append((0, 0)) else: absM = abs(M) for i in range(1, int(math.isqrt(absM)) + 1): if absM % i == 0: j = absM // i for s_sign in [1, -1]: for t_sign in [1, -1]: s = s_sign * i t = t_sign * j if M == s * t and (s + t) % 2 == 0: factors.append((s, t)) if i != j: s = s_sign * j t = t_sign * i if M == s * t and (s + t) % 2 == 0: factors.append((s, t)) # Deduplicate factors factors = list(set(factors)) inv4 = pow(4, MOD-2, MOD) inv4_powN = pow(inv4, N, MOD) total = 0 for s, t in factors: k = (s + t) // 2 l = (t - s) // 2 # Check if a and b exist a_min = max(abs(k), 0) b_min = abs(l) a_max = N - b_min if a_min > a_max: continue # Check parity if (k + (N - l)) % 2 != 0: continue # a must be >= a_min, <= a_max, and a ≡ k mod 2 a_start = a_min if a_start % 2 != k % 2: a_start += 1 if a_start > a_max: continue a_end = a_max if a_end % 2 != k % 2: a_end -= 1 if a_end < a_start: continue # Compute the sum over a in a_start, a_start+2, ..., a_end sum_contribution = 0 for a in range(a_start, a_end + 1, 2): b = N - a if b < abs(l): continue if (b - l) % 2 != 0: continue c1 = comb(N, a) c2 = comb(a, (a + k) // 2) c3 = comb(b, (b + l) // 2) sum_contribution = (sum_contribution + c1 * c2 % MOD * c3) % MOD total = (total + sum_contribution) % MOD # Multiply by inv4^N total = total * inv4_powN % MOD print(total) if __name__ == "__main__": main()