結果

問題 No.2959 Dolls' Tea Party
ユーザー qwewe
提出日時 2025-05-14 13:12:02
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 11,488 bytes
コンパイル時間 289 ms
コンパイル使用メモリ 81,792 KB
実行使用メモリ 178,068 KB
最終ジャッジ日時 2025-05-14 13:13:43
合計ジャッジ時間 6,722 ms
ジャッジサーバーID (参考情報) judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 2 TLE * 1 -- * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Set higher recursion depth if needed, although iterative NTT avoids deep recursion.
# sys.setrecursionlimit(2000) 

_MOD = 998244353
# 998244353 = 119 * 2**23 + 1 and 3 is a primitive root modulo it, so NTT
# lengths up to 2**23 are supported.
_PRIMITIVE_ROOT = 3
# Operands at most this long are multiplied schoolbook-style; below this size
# three pure-Python NTTs cost more than the direct O(len_f * len_g) product.
_NAIVE_MUL_THRESHOLD = 48


def _ntt(a, invert):
    """Transform ``a`` in place; ``len(a)`` must be a power of two.

    Forward NTT when ``invert`` is False.  Inverse NTT, including the 1/n
    scaling, when ``invert`` is True.
    """
    mod = _MOD
    n = len(a)
    # Bit-reversal permutation.  ``j`` is maintained as the bit-reversed value
    # of ``i`` incrementally, which is O(n) overall instead of recomputing
    # every reversal bit by bit (O(n log n) interpreted operations).
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j |= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    length = 2
    while length <= n:
        half = length >> 1
        w_len = pow(_PRIMITIVE_ROOT, (mod - 1) // length, mod)
        if invert:
            w_len = pow(w_len, mod - 2, mod)
        # Twiddle factors w_len**k are computed once per level and shared by
        # every butterfly block of that level.
        twiddles = [1] * half
        for k in range(1, half):
            twiddles[k] = twiddles[k - 1] * w_len % mod
        for start in range(0, n, length):
            for k in range(half):
                p = start + k
                q = p + half
                u = a[p]
                v = a[q] * twiddles[k] % mod
                a[p] = (u + v) % mod
                a[q] = (u - v) % mod  # Python's % is already non-negative
        length <<= 1

    if invert:
        n_inv = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * n_inv % mod


def _mul_trunc(f, g, limit):
    """Return the coefficients of f*g of degree <= ``limit``.

    The result always has exactly ``limit + 1`` entries (zero padded).
    """
    mod = _MOD
    size = limit + 1
    # Coefficients above ``limit`` cannot influence the kept part of the product.
    f = f[:size]
    g = g[:size]
    res = [0] * size
    if not f or not g:
        return res
    if len(f) > len(g):
        f, g = g, f
    lf, lg = len(f), len(g)

    if lf <= _NAIVE_MUL_THRESHOLD:
        for i, fi in enumerate(f):
            if fi:
                for j in range(min(lg, size - i)):
                    res[i + j] = (res[i + j] + fi * g[j]) % mod
        return res

    # A cyclic convolution of length n >= lf + lg - 1 has no wrap-around.
    n = 1
    while n < lf + lg - 1:
        n <<= 1
    fa = f + [0] * (n - lf)
    ga = g + [0] * (n - lg)
    _ntt(fa, False)
    _ntt(ga, False)
    for i in range(n):
        fa[i] = fa[i] * ga[i] % mod
    _ntt(fa, True)
    out_len = min(lf + lg - 1, size)
    res[:out_len] = fa[:out_len]
    return res


def _sqr_trunc(f, limit):
    """Return f*f truncated to degree ``limit`` (``limit + 1`` entries).

    Uses a single forward transform instead of two.
    """
    mod = _MOD
    size = limit + 1
    f = f[:size]
    lf = len(f)
    if lf <= _NAIVE_MUL_THRESHOLD:
        return _mul_trunc(f, f, limit)
    n = 1
    while n < 2 * lf - 1:
        n <<= 1
    fa = f + [0] * (n - lf)
    _ntt(fa, False)
    for i in range(n):
        fa[i] = fa[i] * fa[i] % mod
    _ntt(fa, True)
    res = fa[:size]
    if len(res) < size:
        res.extend([0] * (size - len(res)))
    return res


def _pow_ode(f, exponent, out_len, inv):
    """Return the first ``out_len`` coefficients of f**exponent, where f[0] == 1.

    Q = f**c satisfies f * Q' = c * f' * Q.  Comparing the coefficients of
    x**(t-1) gives
        q_t = (1/t) * sum_{j=1..min(deg f, t)} f_j * q_{t-j} * ((c+1)*j - t),
    which costs O(deg(f) * out_len).  ``inv[t]`` must equal 1/t modulo
    _MOD for every 1 <= t < out_len.
    """
    mod = _MOD
    deg = len(f) - 1
    c1 = (exponent + 1) % mod
    q = [0] * out_len
    q[0] = 1
    for t in range(1, out_len):
        acc = 0
        for j in range(1, min(deg, t) + 1):
            acc = (acc + f[j] * q[t - j] % mod * ((c1 * j - t) % mod)) % mod
        q[t] = acc * inv[t] % mod
    return q


def _pow_binary(f, exponent, limit):
    """Return f**exponent truncated to degree ``limit`` by repeated squaring.

    ``exponent`` must be >= 1.
    """
    result = None
    base = f[:limit + 1]
    while True:
        if exponent & 1:
            result = base[:] if result is None else _mul_trunc(result, base, limit)
        exponent >>= 1
        if not exponent:
            return result
        base = _sqr_trunc(base, limit)


def _pow_trunc(f, exponent, limit, inv):
    """Return f**exponent truncated to degree ``limit``.

    Requires f[0] == 1 and inv[t] == 1/t (mod _MOD) for 1 <= t <= limit.
    The result has min(deg(f) * exponent, limit) + 1 coefficients.  It picks
    whichever of the recurrence and NTT-based binary powering is estimated to
    be cheaper.
    """
    f = f[:limit + 1]
    deg = len(f) - 1
    if exponent == 0 or deg <= 0:
        return [1]
    if exponent == 1:
        return f
    out_len = min(deg * exponent, limit) + 1
    n = 1
    while n < 2 * out_len - 1:
        n <<= 1
    # Rough operation counts: the recurrence does deg*out_len multiply-adds.
    # Binary powering does about (squarings + multiplies) products, each
    # costing three transforms of n/2 * log2(n) butterflies.
    multiplications = exponent.bit_length() + bin(exponent).count("1") - 2
    ntt_cost = multiplications * 3 * (n >> 1) * n.bit_length()
    if deg * out_len <= ntt_cost:
        return _pow_ode(f, exponent, out_len, inv)
    return _pow_binary(f, exponent, out_len - 1)


def _phi(m):
    """Return Euler's totient of the positive integer ``m`` (trial division)."""
    result = m
    p = 2
    while p * p <= m:
        if m % p == 0:
            while m % p == 0:
                m //= p
            result -= result // p
        p += 1
    if m > 1:
        result -= result // m
    return result


def solve():
    """Read ``N K`` and ``A_1 .. A_N`` from stdin and print the answer mod 998244353.

    The answer counts rotation classes of length-K sequences in which kind i
    appears at most A_i times.

    Burnside's lemma: a rotation whose cycle structure is d cycles of length
    L = K/d occurs phi(L) times among the K rotations.  It fixes
        F(d) = d! * [x^d] prod_i E_{min(A_i // L, d)}(x)
    sequences, where E_m(x) = sum_{p<=m} x^p / p!.  The answer is
    (1/K) * sum_d phi(K/d) * F(d).
    """
    mod = _MOD
    data = sys.stdin.read().split()
    N, K = int(data[0]), int(data[1])
    A = [int(v) for v in data[2:2 + N]]

    fact = [1] * (K + 1)
    for i in range(1, K + 1):
        fact[i] = fact[i - 1] * i % mod
    invfact = [1] * (K + 1)
    invfact[K] = pow(fact[K], mod - 2, mod)
    for i in range(K, 0, -1):
        invfact[i - 1] = invfact[i] * i % mod
    # inv[t] = 1/t, used by the power recurrence (1/t = (t-1)! / t!).
    inv = [0] * (K + 1)
    for i in range(1, K + 1):
        inv[i] = invfact[i] * fact[i - 1] % mod

    divisors = []
    i = 1
    while i * i <= K:
        if K % i == 0:
            divisors.append(i)
            if i * i != K:
                divisors.append(K // i)
        i += 1

    total = 0
    for d in divisors:
        L = K // d
        # Group kinds by their effective cap min(A_i // L, d).  Every kind
        # with A_i // L >= d contributes the same truncated polynomial, so
        # grouping by the raw quotient would repeat identical work.
        groups = {}
        for x in A:
            cap = min(x // L, d)
            if cap:
                groups[cap] = groups.get(cap, 0) + 1

        poly = [1]
        full = groups.pop(d, 0)
        if full:
            # E_d(x) == exp(x) mod x^(d+1), hence E_d^c == exp(c*x) and its
            # coefficients are c^p / p!.
            poly = [0] * (d + 1)
            power = 1
            for p in range(d + 1):
                poly[p] = power * invfact[p] % mod
                power = power * full % mod

        for cap, cnt in groups.items():
            factor = _pow_trunc(invfact[:cap + 1], cnt, d, inv)
            poly = _mul_trunc(poly, factor, d)

        coeff = poly[d] if len(poly) > d else 0
        fixed_count = coeff * fact[d] % mod
        total = (total + _phi(L) * fixed_count) % mod

    print(total * pow(K, mod - 2, mod) % mod)

# Run only when executed as a script, so the solver can be imported and
# tested without consuming stdin.  (The stray no-op ``0`` expression left
# over from the page scrape has been removed.)
if __name__ == "__main__":
    solve()