結果
| 問題 |
No.2406 Difference of Coordinate Squared
|
| コンテスト | |
| ユーザー |
gew1fw
|
| 提出日時 | 2025-06-12 16:03:03 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 3,244 bytes |
| コンパイル時間 | 241 ms |
| コンパイル使用メモリ | 82,500 KB |
| 実行使用メモリ | 99,844 KB |
| 最終ジャッジ日時 | 2025-06-12 16:03:08 |
| 合計ジャッジ時間 | 4,022 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 1 TLE * 1 -- * 53 |
ソースコード
import sys
import math
MOD = 998244353
def main():
N, M = map(int, sys.stdin.readline().split())
max_n = N
fact = [1] * (max_n + 1)
for i in range(1, max_n + 1):
fact[i] = fact[i-1] * i % MOD
inv_fact = [1] * (max_n + 1)
inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD)
for i in range(max_n-1, -1, -1):
inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
inv_2 = pow(2, MOD-2, MOD)
inv_2_pows = pow(inv_2, 2*N, MOD)
def comb(n, k):
if k < 0 or k > n:
return 0
return fact[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD
def get_divisors(m):
divisors = set()
if m == 0:
return divisors
m_abs = abs(m)
for i in range(1, int(math.isqrt(m_abs)) + 1):
if m_abs % i == 0:
divisors.add(i)
divisors.add(m_abs // i)
divisors.add(m_abs)
res = set()
for d in divisors:
res.add(d)
res.add(-d)
return res
answer = 0
if M == 0:
for a in range(N+1):
b = N - a
sum_x_sq = 0
sum_y_sq = 0
for s in range(-a, a+1, 2):
cnt_x = comb(a, (a + s) // 2)
if cnt_x == 0:
continue
for t in [-s, s]:
if abs(t) > b or (b - t) % 2 != 0:
continue
cnt_y = comb(b, (b + t) // 2)
if cnt_y == 0:
continue
term = comb(N, a)
term = term * cnt_x % MOD
term = term * cnt_y % MOD
term = term * inv_2_pows % MOD
answer = (answer + term) % MOD
print(answer)
return
divisors = get_divisors(M)
required_parity = N % 2
for A in divisors:
B = M // A
if (A % 2 != required_parity) or (B % 2 != required_parity):
continue
X = (A + B) // 2
Y = (B - A) // 2
if (A + B) % 2 != 0 or (B - A) % 2 != 0:
continue
abs_X = abs(X)
abs_Y = abs(Y)
min_a = abs_X
max_a = N - abs_Y
if min_a > max_a:
continue
a_parity = X % 2
start_a = min_a if (min_a % 2 == a_parity) else (min_a + 1)
if start_a > max_a:
continue
num_steps = (max_a - start_a) // 2 + 1
for k in range(num_steps):
a = start_a + 2 * k
b_val = N - a
if b_val < abs_Y:
continue
if (b_val % 2) != (Y % 2):
continue
c_n_a = comb(N, a)
if c_n_a == 0:
continue
x = X
y = Y
c_a = comb(a, (a + x) // 2)
c_b = comb(b_val, (b_val + y) // 2)
if c_a == 0 or c_b == 0:
continue
term = c_n_a * c_a % MOD
term = term * c_b % MOD
term = term * inv_2_pows % MOD
answer = (answer + term) % MOD
print(answer % MOD)
if __name__ == '__main__':
main()
gew1fw