結果
| 問題 |
No.2406 Difference of Coordinate Squared
|
| コンテスト | |
| ユーザー |
gew1fw
|
| 提出日時 | 2025-06-12 14:03:22 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 3,140 bytes |
| コンパイル時間 | 336 ms |
| コンパイル使用メモリ | 82,304 KB |
| 実行使用メモリ | 93,556 KB |
| 最終ジャッジ日時 | 2025-06-12 14:04:17 |
| 合計ジャッジ時間 | 6,228 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 47 WA * 8 |
ソースコード
import sys
import math
MOD = 998244353
def main():
N, M = map(int, sys.stdin.readline().split())
if M == 0:
# Special case, X^2 = Y^2
pass # Handle separately if needed, but current code can process it
# Precompute factorials and inverse factorials modulo MOD
max_n = N
fact = [1] * (max_n + 1)
for i in range(1, max_n + 1):
fact[i] = fact[i-1] * i % MOD
inv_fact = [1] * (max_n + 1)
inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD)
for i in range(max_n-1, -1, -1):
inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
def comb(n, k):
if k < 0 or k > n:
return 0
return fact[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD
# Function to compute all (a, b) pairs such that a * b = M
def get_divisors(M):
if M == 0:
return set()
divisors = set()
# Find all positive divisors of |M|
m_abs = abs(M)
for i in range(1, int(math.isqrt(m_abs)) + 1):
if m_abs % i == 0:
divisors.add(i)
divisors.add(m_abs // i)
# Generate all possible a = d * s where s is ±1
a_list = []
for d in divisors:
a_list.append(d)
a_list.append(-d)
pairs = set()
for a in a_list:
if M % a != 0:
continue
b = M // a
pairs.add((a, b))
return pairs
pairs = get_divisors(M)
sum_total = 0
inv4_powN = pow(4, N, MOD)
inv4_powN = pow(inv4_powN, MOD-2, MOD) # since (1/4)^N = 4^(-N) mod MOD
for a, b in pairs:
if a % 2 != b % 2:
continue
s_x = (a + b) // 2
s_y = (b - a) // 2
# Check congruence conditions for k
if (s_x % 2) != ((N - s_y) % 2):
continue
t = s_x % 2
# Determine the range of k
L = max(abs(s_x), 0)
U = N - abs(s_y)
if U < L:
continue
# Find first_k and last_k
if L % 2 == t:
first_k = L
else:
first_k = L + 1
if first_k > U:
continue
if U % 2 == t:
last_k = U
else:
last_k = U - 1
if last_k < first_k:
continue
# Iterate over k from first_k to last_k, step 2
step = 2
for k in range(first_k, last_k + 1, step):
m = N - k
# Check x_plus and y_plus
x_plus = (k + s_x) // 2
if x_plus < 0 or x_plus > k:
continue
y_plus = (m + s_y) // 2
if y_plus < 0 or y_plus > m:
continue
# Compute combinations
c_n_k = comb(N, k)
c_k_x = comb(k, x_plus)
c_m_y = comb(m, y_plus)
term = c_n_k * c_k_x % MOD
term = term * c_m_y % MOD
sum_total = (sum_total + term) % MOD
# Multiply by inv4^N
result = sum_total * inv4_powN % MOD
print(result)
if __name__ == "__main__":
main()
gew1fw