結果

問題 No.2406 Difference of Coordinate Squared
ユーザー qwewe
提出日時 2025-05-14 13:04:57
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 7,526 bytes
コンパイル時間 210 ms
コンパイル使用メモリ 82,288 KB
実行使用メモリ 89,656 KB
最終ジャッジ日時 2025-05-14 13:06:34
合計ジャッジ時間 5,216 ms
ジャッジサーバーID (参考情報) judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 40 WA * 15
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Set higher recursion depth if necessary, although unlikely needed for this problem
# sys.setrecursionlimit(2000) 

def _probability_mod(N, M, mod=998244353):
    """Return P(U * V == M) modulo ``mod`` for two independent N-step +-1 sums U, V.

    U and V are each a sum of N independent +-1 steps, so

        P(U = u) = C(N, (N + u) / 2) / 2^N

    when |u| <= N and u has the same parity as N, and 0 otherwise.  The
    result is sum over all ordered pairs (u, v) with u * v == M of
    P(U = u) * P(V = v).  Divisions are done with Fermat inverses, so
    ``mod`` must be prime.

    Args:
        N: number of steps (assumed >= 0).
        M: target product; may be zero or negative.
        mod: prime modulus.

    Returns:
        The probability as an integer in [0, mod).
    """
    # fact[i] = i! and inv_fact[i] = (i!)^-1, both reduced mod ``mod``.
    fact = [1] * (N + 1)
    for i in range(1, N + 1):
        fact[i] = fact[i - 1] * i % mod
    inv_fact = [1] * (N + 1)
    inv_fact[N] = pow(fact[N], mod - 2, mod)
    for i in range(N, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % mod

    def coeff(u):
        """Return C(N, (N + u) / 2) mod ``mod`` (= 2^N * P(U = u)), or 0 if U = u is unreachable."""
        # Both the range AND the parity must be checked for every coordinate
        # independently: an even N can be paired with an even u but an odd
        # v = M / u, and such a pair must contribute nothing.
        if abs(u) > N or (N - u) % 2 != 0:
            return 0
        k = (N + u) // 2
        return fact[N] * inv_fact[k] % mod * inv_fact[N - k] % mod

    inv4_pow_n = pow(pow(4, mod - 2, mod), N, mod)

    if M == 0:
        # Inclusion-exclusion: P(U=0 or V=0) = 2 P(U=0) - P(U=0)^2
        #   = (2 * c * 2^N - c^2) / 4^N  with c = C(N, N/2) (c = 0 for odd N).
        c = coeff(0)
        return (2 * c * pow(2, N, mod) - c * c) % mod * inv4_pow_n % mod

    abs_m = abs(M)
    total = 0
    d = 1
    # Enumerate divisor pairs d * e == |M| with d <= e using exact integer
    # arithmetic.  Any divisor larger than N has a zero coefficient, so the
    # scan can also stop once d exceeds N.
    while d * d <= abs_m and d <= N:
        if abs_m % d == 0:
            e = abs_m // d
            term = coeff(d) * coeff(e) % mod
            # Ordered pairs with {|u|, |v|} = {d, e}: u = +-d and, when e != d,
            # also u = +-e; the sign of v = M / u is then forced.  coeff is
            # even in u, so each sign choice contributes the same term.
            total = (total + term * (2 if d == e else 4)) % mod
        d += 1
    return total * inv4_pow_n % mod


def solve():
    """Read ``N M`` from one line of stdin and print the answer mod 998244353.

    The answer is the probability that U * V == M, where U and V are
    independent sums of N steps of +-1 each.  That is the model used by the
    original comments (U = X + Y, V = X - Y, so X^2 - Y^2 = U * V).  It
    presumably comes from a walker taking N uniform axis-aligned unit steps;
    TODO confirm against the problem statement.
    """
    N, M = map(int, sys.stdin.readline().split())
    print(_probability_mod(N, M))

# Run only when executed as a script, so the module can be imported (e.g. by
# tests) without blocking on stdin.  The stray no-op `0` expression that
# followed the call has been removed.
if __name__ == "__main__":
    solve()