結果

問題 No.2062 Sum of Subset mod 999630629
ユーザー qwewe
提出日時 2025-05-14 12:56:18
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 10,406 bytes
コンパイル時間 217 ms
コンパイル使用メモリ 82,724 KB
実行使用メモリ 127,024 KB
最終ジャッジ日時 2025-05-14 12:58:34
合計ジャッジ時間 29,542 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 22 TLE * 1 -- * 6
権限があれば一括ダウンロードができます

ソースコード

diff #

# -*- coding: utf-8 -*-
import sys
import collections

# Use optimized input reading for speed in competitive programming.
# This shadows the builtin `input` for the rest of the module.
# sys.stdin.readline returns each line including its trailing "\n";
# int() and str.split() both tolerate that newline.
input = sys.stdin.readline

# Moduli fixed by the problem statement.
_P1 = 999630629  # each subset sum is reduced modulo this prime
_MOD = 998244353  # the answer is reported modulo this NTT-friendly prime
_G = 3  # primitive root of _MOD (998244353 = 119 * 2**23 + 1)
# At or below this operand length a schoolbook product beats the NTT's overhead.
_NAIVE_THRESHOLD = 32


def _ntt(a, invert):
    """Transform list ``a`` in place over Z/_MOD (inverse transform if ``invert``).

    ``len(a)`` must be a power of two no larger than 2**23, the largest
    power of two dividing _MOD - 1.
    """
    mod = _MOD
    n = len(a)
    # Bit-reversal permutation so the butterflies can run bottom-up in place.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:
        half = length >> 1
        root = pow(_G, (mod - 1) // length, mod)
        if invert:
            root = pow(root, mod - 2, mod)
        # Twiddles root**0 .. root**(half-1); shared by every block of this level.
        twiddles = [1] * half
        for k in range(1, half):
            twiddles[k] = twiddles[k - 1] * root % mod
        for start in range(0, n, length):
            for k in range(half):
                p = start + k
                q = p + half
                u = a[p]
                v = a[q] * twiddles[k] % mod
                a[p] = (u + v) % mod
                a[q] = (u - v) % mod
        length <<= 1
    if invert:
        n_inv = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * n_inv % mod


def _convolve(a, b, limit=None):
    """Return the product of coefficient lists ``a`` and ``b`` modulo _MOD.

    If ``limit`` is given, only the first ``limit`` coefficients are kept.
    The result is never padded beyond the true product length, so repeated
    products do not drift up to power-of-two sizes. Inputs are not mutated.
    """
    mod = _MOD
    if limit is not None:
        a = a[:limit]
        b = b[:limit]
    if not a or not b:
        return []
    full = len(a) + len(b) - 1
    out_len = full if limit is None else min(full, limit)
    if min(len(a), len(b)) <= _NAIVE_THRESHOLD:
        if len(a) < len(b):
            a, b = b, a  # iterate over the shorter operand
        res = [0] * out_len
        len_a = len(a)
        for j, bj in enumerate(b):
            if bj:
                for i in range(min(len_a, out_len - j)):
                    res[i + j] = (res[i + j] + a[i] * bj) % mod
        return res
    size = 1
    while size < full:
        size <<= 1
    fa = a + [0] * (size - len(a))
    fb = b + [0] * (size - len(b))
    _ntt(fa, False)
    _ntt(fb, False)
    prod = [x * y % mod for x, y in zip(fa, fb)]
    _ntt(prod, True)
    return prod[:out_len]


def _inverses(n):
    """Return ``inv`` of length n+1 with ``inv[i] = i**-1 mod _MOD`` for 1 <= i <= n."""
    mod = _MOD
    inv = [0] * (n + 1)
    if n >= 1:
        inv[1] = 1
    for i in range(2, n + 1):
        # Standard linear recurrence: i^-1 = -(mod // i) * (mod % i)^-1.
        inv[i] = (mod - mod // i) * inv[mod % i] % mod
    return inv


def _poly_inv(f, n):
    """Return g (length n) with f * g == 1 mod x**n; requires f[0] != 0."""
    mod = _MOD
    g = [pow(f[0], mod - 2, mod)]
    m = 1
    while m < n:
        m <<= 1
        # Newton step: g <- g * (2 - f*g) doubles the number of correct terms.
        fg = _convolve(f[:m], g, m)
        t = [(-c) % mod for c in fg]
        t[0] = (t[0] + 2) % mod
        g = _convolve(g, t, m)
    del g[n:]
    g.extend([0] * (n - len(g)))
    return g


def _poly_log(f, n, invs):
    """Return log(f) mod x**n (length n); requires f[0] == 1 and len(invs) >= n."""
    mod = _MOD
    res = [0] * n
    if n <= 1:
        return res
    f = f[:n]
    deriv = [f[i] * i % mod for i in range(1, len(f))]
    if not deriv:
        return res  # f == 1, so log(f) == 0
    # log f = integral(f' / f); f'/f is only needed mod x**(n-1).
    quotient = _convolve(deriv, _poly_inv(f, n - 1), n - 1)
    for i, c in enumerate(quotient):
        res[i + 1] = c * invs[i + 1] % mod
    return res


def _poly_exp(f, n):
    """Return exp(f) mod x**n (length n); requires f[0] == 0."""
    mod = _MOD
    size = 1
    while size < n:
        size <<= 1
    invs = _inverses(size)
    g = [1]
    m = 1
    while m < n:
        m <<= 1
        # Newton step: g <- g * (1 - log g + f) doubles the number of correct terms.
        t = [(-c) % mod for c in _poly_log(g, m, invs)]
        for i in range(min(m, len(f))):
            t[i] = (t[i] + f[i]) % mod
        t[0] = (t[0] + 1) % mod
        g = _convolve(g, t, m)
    del g[n:]
    g.extend([0] * (n - len(g)))
    return g


def _count_subsets_with_sum_at_most(values, limit):
    """Count index subsets (including the empty one) of ``values`` with sum <= limit, mod _MOD.

    ``values`` must be non-negative integers. The count is the sum of the
    coefficients of prod_v (1 + x**v)**m_v up to x**limit. That product is
    built as exp(sum_v m_v * log(1 + x**v)) in O(limit log limit). A value
    above ``limit`` can never be in a qualifying subset and is skipped. Each
    zero doubles the count.
    """
    mod = _MOD
    if limit < 0:
        return 0
    zeros = 0
    multiplicity = collections.Counter()
    for v in values:
        if v == 0:
            zeros += 1
        elif v <= limit:
            multiplicity[v] += 1
    count = 1  # only the empty subset when no value fits under `limit`
    if multiplicity:
        n = limit + 1
        invs = _inverses(limit)
        # log(1 + x**v) = sum_{j>=1} (-1)**(j+1) * x**(v*j) / j  (harmonic-sum many terms)
        log_p = [0] * n
        for v, m in multiplicity.items():
            for j in range(1, limit // v + 1):
                term = m * invs[j] % mod
                idx = v * j
                if j & 1:
                    log_p[idx] = (log_p[idx] + term) % mod
                else:
                    log_p[idx] = (log_p[idx] - term) % mod
        count = sum(_poly_exp(log_p, n)) % mod
    return count * pow(2, zeros, mod) % mod


def _subset_mod_sum(values):
    """Return sum over non-empty subsets S of (sum(S) mod _P1), reduced mod _MOD.

    Assumes sum(values) < 2 * _P1, so every subset sum wraps at most once.
    Then the total is 2**(N-1) * sum(values) - _P1 * #{S : sum(S) >= _P1}.
    Complementation maps {S : sum(S) >= _P1} bijectively onto the subsets
    with sum <= sum(values) - _P1.
    """
    if not values:
        return 0
    total = sum(values)
    base = pow(2, len(values) - 1, _MOD) * (total % _MOD) % _MOD
    wrapped = _count_subsets_with_sum_at_most(values, total - _P1)
    return (base - _P1 * wrapped) % _MOD


def solve():
    """Read N and A_1..A_N from stdin and print the answer to yukicoder No.2062.

    All whitespace-separated tokens are read, so A may span several lines.
    """
    data = sys.stdin.read().split()
    n = int(data[0])
    values = [int(tok) for tok in data[1:1 + n]]
    print(_subset_mod_sum(values))

# Execute the main function only when run as a script, so that importing
# this module (e.g. from tests) does not block waiting on stdin.
if __name__ == "__main__":
    solve()
0