結果
問題 |
No.2406 Difference of Coordinate Squared
|
ユーザー |
![]() |
提出日時 | 2025-06-12 16:03:22 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 2,903 bytes |
コンパイル時間 | 329 ms |
コンパイル使用メモリ | 82,532 KB |
実行使用メモリ | 93,312 KB |
最終ジャッジ日時 | 2025-06-12 16:03:43 |
合計ジャッジ時間 | 6,600 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 47 WA * 8 |
ソースコード
MOD = 998244353 def main(): import sys N, M = map(int, sys.stdin.readline().split()) max_n = N fact = [1] * (max_n + 1) for i in range(1, max_n + 1): fact[i] = fact[i-1] * i % MOD inv_fact = [1] * (max_n + 1) inv_fact[max_n] = pow(fact[max_n], MOD-2, MOD) for i in range(max_n-1, -1, -1): inv_fact[i] = inv_fact[i+1] * (i+1) % MOD inv_2_2N = pow(4, N, MOD) inv_2_2N = pow(inv_2_2N, MOD-2, MOD) def comb(n, k): if k < 0 or k > n: return 0 return fact[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD def get_factors(m): if m == 0: return [] m_abs = abs(m) factors = {} i = 2 while i * i <= m_abs: while m_abs % i == 0: factors[i] = factors.get(i, 0) + 1 m_abs //= i i += 1 if m_abs > 1: factors[m_abs] = factors.get(m_abs, 0) + 1 divisors = [1] for p, exp in factors.items(): temp = [] for d in divisors: current = d for e in range(1, exp + 1): current *= p temp.append(current) divisors += temp divisors = list(set(divisors)) all_divisors = [] for d in divisors: all_divisors.append(d) all_divisors.append(-d) all_divisors = list(set(all_divisors)) factor_pairs = [] for a in all_divisors: if a == 0: continue if m % a != 0: continue b = m // a if (a % 2) == (b % 2): factor_pairs.append((a, b)) return factor_pairs factor_pairs = get_factors(M) total = 0 seen = set() for a, b in factor_pairs: if (a, b) in seen: continue seen.add((a, b)) x = (a + b) // 2 z = (b - a) // 2 if (x + (N - z)) % 2 != 0: continue k_min = max(abs(x), 0) k_max = N - abs(z) if k_min > k_max: continue parity = x % 2 start = k_min if (k_min % 2 == parity) else k_min + 1 if start > k_max: continue step = 2 num_terms = ((k_max - start) // step) + 1 for k in range(start, k_max + 1, step): l = N - k if l < abs(z): continue a_x = (x + k) // 2 b_z = (z + l) // 2 c_nk = comb(N, k) c_ka = comb(k, a_x) c_lb = comb(l, b_z) term = c_nk * c_ka % MOD term = term * c_lb % MOD total = (total + term) % MOD total = total * inv_2_2N % MOD print(total) if __name__ == '__main__': main()