結果
| 問題 |
No.2406 Difference of Coordinate Squared
|
| コンテスト | |
| ユーザー |
gew1fw
|
| 提出日時 | 2025-06-12 20:50:10 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 2,932 bytes |
| コンパイル時間 | 615 ms |
| コンパイル使用メモリ | 81,808 KB |
| 実行使用メモリ | 92,760 KB |
| 最終ジャッジ日時 | 2025-06-12 20:54:09 |
| 合計ジャッジ時間 | 7,006 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 47 WA * 8 |
ソースコード
import sys
import math
MOD = 998244353
def main():
N, M = map(int, sys.stdin.readline().split())
max_n = N
# Precompute factorial and inverse factorial modulo MOD
factorial = [1] * (max_n + 1)
for i in range(1, max_n + 1):
factorial[i] = factorial[i-1] * i % MOD
inv_fact = [1] * (max_n + 1)
inv_fact[max_n] = pow(factorial[max_n], MOD-2, MOD)
for i in range(max_n-1, -1, -1):
inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
def comb(n, k):
if k < 0 or k > n:
return 0
return factorial[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD
# Enumerate all (s, t) pairs where s*t = M and s and t have the same parity
factors = []
if M == 0:
# All pairs (s, 0) and (0, t), but s and t must have the same parity
# So s must be even (since 0 is even)
factors.append((0, 0))
else:
absM = abs(M)
for i in range(1, int(math.isqrt(absM)) + 1):
if absM % i == 0:
j = absM // i
for s_sign in [1, -1]:
for t_sign in [1, -1]:
s = s_sign * i
t = t_sign * j
if M == s * t and (s + t) % 2 == 0:
factors.append((s, t))
if i != j:
s = s_sign * j
t = t_sign * i
if M == s * t and (s + t) % 2 == 0:
factors.append((s, t))
# Deduplicate factors
factors = list(set(factors))
inv4 = pow(4, MOD-2, MOD)
inv4_powN = pow(inv4, N, MOD)
total = 0
for s, t in factors:
k = (s + t) // 2
l = (t - s) // 2
# Check if a and b exist
a_min = max(abs(k), 0)
b_min = abs(l)
a_max = N - b_min
if a_min > a_max:
continue
# Check parity
if (k + (N - l)) % 2 != 0:
continue
# a must be >= a_min, <= a_max, and a ≡ k mod 2
a_start = a_min
if a_start % 2 != k % 2:
a_start += 1
if a_start > a_max:
continue
a_end = a_max
if a_end % 2 != k % 2:
a_end -= 1
if a_end < a_start:
continue
# Compute the sum over a in a_start, a_start+2, ..., a_end
sum_contribution = 0
for a in range(a_start, a_end + 1, 2):
b = N - a
if b < abs(l):
continue
if (b - l) % 2 != 0:
continue
c1 = comb(N, a)
c2 = comb(a, (a + k) // 2)
c3 = comb(b, (b + l) // 2)
sum_contribution = (sum_contribution + c1 * c2 % MOD * c3) % MOD
total = (total + sum_contribution) % MOD
# Multiply by inv4^N
total = total * inv4_powN % MOD
print(total)
if __name__ == "__main__":
main()
gew1fw