結果
問題 | No.2959 Dolls' Tea Party |
ユーザー |
![]() |
提出日時 | 2025-05-14 13:12:02 |
言語 | PyPy3 (7.3.15) |
結果 | TLE |
実行時間 | - |
コード長 | 11,488 bytes |
コンパイル時間 | 289 ms |
コンパイル使用メモリ | 81,792 KB |
実行使用メモリ | 178,068 KB |
最終ジャッジ日時 | 2025-05-14 13:13:43 |
合計ジャッジ時間 | 6,722 ms |
ジャッジサーバーID (参考情報) | judge1 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 2 TLE * 1 -- * 30 |
ソースコード
import sys

# yukicoder No.2959 "Dolls' Tea Party".
#
# Count, modulo MOD, the arrangements of K dolls on a circle where color i is
# used at most A_i times, treating arrangements that differ by a rotation as
# the same.
#
# Burnside over the K rotations: the rotations with d = gcd(shift, K) cycles
# (each cycle of length L = K / d) number phi(L), and each fixes exactly the
# colorings that are constant on every cycle.  Such a coloring gives p_i whole
# cycles to color i with sum(p_i) = d and L * p_i <= A_i, so it fixes
#     F(d) = d! * [x^d] prod_i E_{min(A_i // L, d)}(x),
#     E_m(x) = sum_{p=0..m} x^p / p!,
# and the answer is (1 / K) * sum_{d | K} phi(K / d) * F(d).
#
# A pure-Python NTT per color group is far too slow here, so the product is
# obtained from its logarithmic derivative with plain recurrences instead.

MOD = 998244353


def _euler_phi(m):
    """Return Euler's totient phi(m) for a positive integer m (trial division)."""
    result = m
    p = 2
    while p * p <= m:
        if m % p == 0:
            while m % p == 0:
                m //= p
            result -= result // p
        p += 1
    if m > 1:  # a single prime factor above the square root is left over
        result -= result // m
    return result


def _divisors(k):
    """Return the positive divisors of k in ascending order."""
    small, large = [], []
    i = 1
    while i * i <= k:  # exact integer bound, no float sqrt rounding issues
        if k % i == 0:
            small.append(i)
            if i * i != k:
                large.append(k // i)
        i += 1
    return small + large[::-1]


def _factorial_tables(n):
    """Return (fact, invfact, inv) for indices 0..n modulo MOD.

    inv[i] is the modular inverse of i for i >= 1; inv[0] is unused (0).
    """
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    invfact = [1] * (n + 1)
    invfact[n] = pow(fact[n], MOD - 2, MOD)  # Fermat's little theorem; MOD is prime
    for i in range(n, 0, -1):
        invfact[i - 1] = invfact[i] * i % MOD
    inv = [0] * (n + 1)
    for i in range(1, n + 1):
        inv[i] = invfact[i] * fact[i - 1] % MOD  # 1/i = (i-1)! / i!
    return fact, invfact, inv


def _fixed_colorings(d, caps, fact, invfact, inv):
    """Return d! * [x^d] prod_m E_m(x)**caps[m]  (mod MOD).

    caps maps a cap m with 1 <= m <= d to the number of colors having that
    cap.  Colors with cap 0 contribute the factor E_0 = 1 and are omitted by
    the caller.  Requires d >= 1 and factorial tables covering index d.

    Cost: O(d^2) for the exp step, plus O(d - m) per distinct cap m and an
    extra O(m * (d - 2m)) when 2m < d.
    """
    # h holds H = G'/G mod x^d for G = prod E_m^c.  Because E_m' = E_{m-1},
    #     E_m'/E_m = 1 - (x^m / m!) * (1 / E_m),
    # so cap m only needs 1/E_m up to x^(d-m-1).
    h = [0] * d
    for m, c in caps.items():
        h[0] = (h[0] + c) % MOD
        width = d - m
        if width <= 0:
            continue  # m == d: E_d'/E_d == 1 (mod x^d)
        # 1/E_m == e^{-x} (mod x^{m+1}), so the first m+1 terms are closed form.
        q = [0] * width
        for t in range(min(width, m + 1)):
            q[t] = invfact[t] if t % 2 == 0 else MOD - invfact[t]
        # Remaining terms from E_m * q = 1:  q_t = -sum_{j=1..m} q_{t-j} / j!.
        for t in range(m + 1, width):
            s = 0
            for j in range(1, m + 1):
                s = (s + invfact[j] * q[t - j]) % MOD
            q[t] = (MOD - s) % MOD
        scale = c * invfact[m] % MOD
        for t in range(width):
            h[m + t] = (h[m + t] - scale * q[t]) % MOD

    # Recover G from G' = H * G with G(0) = 1:
    #     t * g_t = sum_{j=0..t-1} h_j * g_{t-1-j}.
    g = [0] * (d + 1)
    g[0] = 1
    for t in range(1, d + 1):
        s = 0
        for j in range(t):
            s = (s + h[j] * g[t - 1 - j]) % MOD
        g[t] = s * inv[t] % MOD
    return g[d] * fact[d] % MOD


def _count_arrangements(k, limits):
    """Return the number of rotation-distinct circular arrangements of k dolls
    in which color i appears at most limits[i] times, modulo MOD."""
    fact, invfact, inv = _factorial_tables(k)
    total = 0
    for d in _divisors(k):
        cycle_len = k // d
        # Group colors by the number of cycles they may fill.  Caps above d
        # behave exactly like d, so they are merged into one group.
        caps = {}
        for a in limits:
            m = min(a // cycle_len, d)
            if m > 0:
                caps[m] = caps.get(m, 0) + 1
        fixed = _fixed_colorings(d, caps, fact, invfact, inv)
        total = (total + _euler_phi(cycle_len) * fixed) % MOD
    return total * pow(k, MOD - 2, MOD) % MOD


def solve():
    """Read "N K" followed by A_1..A_N from stdin and print the answer."""
    data = sys.stdin.read().split()
    n, k = int(data[0]), int(data[1])
    limits = [int(v) for v in data[2:2 + n]]
    print(_count_arrangements(k, limits))


if __name__ == "__main__":
    solve()