結果
問題 | No.2406 Difference of Coordinate Squared |
ユーザー |
![]() |
提出日時 | 2025-05-14 13:04:57 |
言語 | PyPy3 (7.3.15) |
結果 | WA |
実行時間 | - |
コード長 | 7,526 bytes |
コンパイル時間 | 210 ms |
コンパイル使用メモリ | 82,288 KB |
実行使用メモリ | 89,656 KB |
最終ジャッジ日時 | 2025-05-14 13:06:34 |
合計ジャッジ時間 | 5,216 ms |
ジャッジサーバーID (参考情報) | judge4 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 40 WA * 15 |
ソースコード
import sys

# Solution for yukicoder No.2406 "Difference of Coordinate Squared".
# Model used by the math below: a walker starts at (0, 0) and takes N steps,
# each one of (+1, 0), (-1, 0), (0, +1), (0, -1) with probability 1/4; the
# program prints P(X^2 - Y^2 == M) as a residue modulo MOD.

MOD = 998244353


def difference_of_squares_probability(N: int, M: int) -> int:
    """Return P(X^2 - Y^2 == M) after N random unit steps, modulo MOD.

    Substituting U = X + Y and V = X - Y, every step changes (U, V) by one
    of the four combinations (+-1, +-1) with equal probability, so U and V
    are independent N-step +-1 walks and X^2 - Y^2 = U * V.  Among the 4^N
    equally likely step sequences, the number with (U, V) = (u, v) is
    w(u) * w(v), where w(s) = C(N, (N + s) / 2) when |s| <= N and s has the
    same parity as N, and 0 otherwise.

    Args:
        N: number of steps (N >= 0; N < MOD so all factorials are invertible).
        M: target value of X^2 - Y^2 (any integer, possibly very large).

    Returns:
        The probability as ``count * (4^N)^(-1) mod MOD``.
    """
    # Factorials and inverse factorials for binomials C(N, k) mod MOD.
    fact = [1] * (N + 1)
    for i in range(1, N + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (N + 1)
    inv_fact[N] = pow(fact[N], MOD - 2, MOD)  # Fermat inverse: MOD is prime
    for i in range(N, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    def walks_ending_at(s):
        """Number of N-step +-1 walks from 0 that end at s, modulo MOD."""
        # Range AND parity must be checked for EVERY coordinate: e.g. with
        # N = 2 the pair (u, v) = (2, 1) has u * v = 2, yet V = 1 is
        # unreachable because V must have the parity of N.
        if abs(s) > N or (N - s) % 2:
            return 0
        k = (N + s) // 2
        return fact[N] * inv_fact[k] % MOD * inv_fact[N - k] % MOD

    if M == 0:
        # U * V == 0  <=>  U == 0 or V == 0.  Inclusion-exclusion over the
        # 4^N sequences: each single event occurs in w(0) * 2^N sequences,
        # both together in w(0)^2.  For odd N, w(0) == 0 so this yields 0.
        w0 = walks_ending_at(0)
        favourable = (2 * w0 * pow(2, N, MOD) - w0 * w0) % MOD
    else:
        # Any contributing pair has u | M and |u| <= N, so scanning |u| up to
        # min(N, |M|) costs O(N) regardless of how large |M| is.  Each u
        # determines v = M // u exactly, so no pair is counted twice.
        abs_m = abs(M)
        favourable = 0
        for a in range(1, min(N, abs_m) + 1):
            if abs_m % a:
                continue
            for u in (a, -a):
                favourable = (
                    favourable + walks_ending_at(u) * walks_ending_at(M // u)
                ) % MOD

    return favourable * pow(pow(4, N, MOD), MOD - 2, MOD) % MOD


def solve():
    """Read "N M" from stdin and print the probability modulo MOD."""
    N, M = map(int, sys.stdin.readline().split())
    print(difference_of_squares_probability(N, M))


if __name__ == "__main__":
    solve()