結果
| 問題 |
No.2406 Difference of Coordinate Squared
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-04-16 00:34:16 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 3,228 bytes |
| コンパイル時間 | 310 ms |
| コンパイル使用メモリ | 82,268 KB |
| 実行使用メモリ | 93,000 KB |
| 最終ジャッジ日時 | 2025-04-16 00:35:57 |
| 合計ジャッジ時間 | 7,177 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 47 WA * 8 |
ソースコード
MOD = 998244353
def main():
import sys
N, M = map(int, sys.stdin.readline().split())
max_n = N
# Precompute factorial and inverse factorial modulo MOD
factorial = [1] * (max_n + 1)
for i in range(1, max_n + 1):
factorial[i] = factorial[i-1] * i % MOD
inv_fact = [1] * (max_n + 1)
inv_fact[max_n] = pow(factorial[max_n], MOD-2, MOD)
for i in range(max_n-1, -1, -1):
inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
def comb(n, k):
if k < 0 or k > n:
return 0
return factorial[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD
if M == 0:
print(0)
return
# Generate all factor pairs (A, B) of M with same parity
factors = []
m_abs = abs(M)
if m_abs == 0:
pass
else:
for d in range(1, int(m_abs**0.5) + 1):
if m_abs % d == 0:
A1 = d
B1 = M // d
if (A1 + B1) % 2 == 0:
factors.append((A1, B1))
A2 = -d
B2 = -M // d
if (A2 + B2) % 2 == 0:
factors.append((A2, B2))
if d != m_abs // d:
A1 = m_abs // d
B1 = M // A1
if (A1 + B1) % 2 == 0:
factors.append((A1, B1))
A2 = -A1
B2 = -B1
if (A2 + B2) % 2 == 0:
factors.append((A2, B2))
seen = set()
unique_factors = []
for A, B in factors:
if (A, B) not in seen:
seen.add((A, B))
unique_factors.append((A, B))
factors = unique_factors
total = 0
for A, B in factors:
X = (A + B) // 2
Y = (B - A) // 2
if (A + B) % 2 != 0 or (B - A) % 2 != 0:
continue
if (X + Y) % 2 != N % 2:
continue
a_min = max(abs(X), 0)
a_max = N - abs(Y)
if a_min > a_max:
continue
a_parity = X % 2
start = a_min if a_min % 2 == a_parity else a_min + 1
end = a_max if a_max % 2 == a_parity else a_max - 1
if start > end:
continue
num_a = ((end - start) // 2) + 1
first_a = start
last_a = end
step = 2
a = first_a
while a <= last_a:
b = N - a
if b < 0:
a += step
continue
if abs(Y) > b:
a += step
continue
if (Y + b) % 2 != 0:
a += step
continue
l = (Y + b) // 2
if l < 0 or l > b:
a += step
continue
k = (X + a) // 2
if k < 0 or k > a:
a += step
continue
c = comb(N, a)
c_a = comb(a, k)
c_b = comb(b, l)
total = (total + c * c_a % MOD * c_b) % MOD
a += step
inv_4 = pow(4, MOD-2, MOD)
inv_4n = pow(inv_4, N, MOD)
ans = total * inv_4n % MOD
print(ans)
if __name__ == '__main__':
main()
lam6er