結果
| 問題 |
No.2406 Difference of Coordinate Squared
|
| コンテスト | |
| ユーザー |
gew1fw
|
| 提出日時 | 2025-06-12 16:03:22 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 2,903 bytes |
| コンパイル時間 | 329 ms |
| コンパイル使用メモリ | 82,532 KB |
| 実行使用メモリ | 93,312 KB |
| 最終ジャッジ日時 | 2025-06-12 16:03:43 |
| 合計ジャッジ時間 | 6,600 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 47 WA * 8 |
ソースコード
MOD = 998244353
def main():
import sys
N, M = map(int, sys.stdin.readline().split())
max_n = N
fact = [1] * (max_n + 1)
for i in range(1, max_n + 1):
fact[i] = fact[i-1] * i % MOD
inv_fact = [1] * (max_n + 1)
inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD)
for i in range(max_n-1, -1, -1):
inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
inv_2_2N = pow(4, N, MOD)
inv_2_2N = pow(inv_2_2N, MOD-2, MOD)
def comb(n, k):
if k < 0 or k > n:
return 0
return fact[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD
def get_factors(m):
if m == 0:
return []
m_abs = abs(m)
factors = {}
i = 2
while i * i <= m_abs:
while m_abs % i == 0:
factors[i] = factors.get(i, 0) + 1
m_abs //= i
i += 1
if m_abs > 1:
factors[m_abs] = factors.get(m_abs, 0) + 1
divisors = [1]
for p, exp in factors.items():
temp = []
for d in divisors:
current = d
for e in range(1, exp + 1):
current *= p
temp.append(current)
divisors += temp
divisors = list(set(divisors))
all_divisors = []
for d in divisors:
all_divisors.append(d)
all_divisors.append(-d)
all_divisors = list(set(all_divisors))
factor_pairs = []
for a in all_divisors:
if a == 0:
continue
if m % a != 0:
continue
b = m // a
if (a % 2) == (b % 2):
factor_pairs.append((a, b))
return factor_pairs
factor_pairs = get_factors(M)
total = 0
seen = set()
for a, b in factor_pairs:
if (a, b) in seen:
continue
seen.add((a, b))
x = (a + b) // 2
z = (b - a) // 2
if (x + (N - z)) % 2 != 0:
continue
k_min = max(abs(x), 0)
k_max = N - abs(z)
if k_min > k_max:
continue
parity = x % 2
start = k_min if (k_min % 2 == parity) else k_min + 1
if start > k_max:
continue
step = 2
num_terms = ((k_max - start) // step) + 1
for k in range(start, k_max + 1, step):
l = N - k
if l < abs(z):
continue
a_x = (x + k) // 2
b_z = (z + l) // 2
c_nk = comb(N, k)
c_ka = comb(k, a_x)
c_lb = comb(l, b_z)
term = c_nk * c_ka % MOD
term = term * c_lb % MOD
total = (total + term) % MOD
total = total * inv_2_2N % MOD
print(total)
if __name__ == '__main__':
main()
gew1fw