結果
| 問題 | No.2959 Dolls' Tea Party |
|---|---|
| コンテスト | |
| ユーザー | qwewe |
| 提出日時 | 2025-05-14 13:12:02 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 11,488 bytes |
| コンパイル時間 | 289 ms |
| コンパイル使用メモリ | 81,792 KB |
| 実行使用メモリ | 178,068 KB |
| 最終ジャッジ日時 | 2025-05-14 13:13:43 |
| 合計ジャッジ時間 | 6,722 ms |
| ジャッジサーバーID (参考情報) | judge1 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 2 TLE * 1 -- * 30 |
ソースコード
import sys
# Set higher recursion depth if needed, although iterative NTT avoids deep recursion.
# sys.setrecursionlimit(2000)
# Counting rotation-distinct rings of dolls with per-kind caps, modulo MOD.
# Burnside's lemma over the K rotations reduces the problem to, for each
# divisor d of K, the coefficient of x^d in a product of truncated
# exponentials; those products are computed with NTT-based polynomial
# arithmetic (plus cheaper special cases where they apply).

MOD = 998244353  # 119 * 2**23 + 1, so NTT lengths up to 2**23 are supported
PRIMITIVE_ROOT = 3  # primitive root of MOD

# Tuning heuristic: if the shorter operand has at most this many coefficients,
# schoolbook convolution is used instead of an NTT.
NAIVE_MULTIPLY_LIMIT = 64

# Per-length tables reused across transforms of the same size.
_BIT_REVERSAL_CACHE = {}
_TWIDDLE_CACHE = {}


def factorials(n):
    """Return ``(fact, inv_fact)`` with i! and (i!)^-1 mod MOD for 0 <= i <= n."""
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], MOD - 2, MOD)
    # (i-1)!^-1 = i!^-1 * i, so one modular inverse suffices.
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD
    return fact, inv_fact


def _bit_reversal(n):
    """Return (cached) bit-reversed indices for a power-of-two length n >= 2."""
    table = _BIT_REVERSAL_CACHE.get(n)
    if table is None:
        log_n = n.bit_length() - 1
        table = [0] * n
        for i in range(1, n):
            table[i] = (table[i >> 1] >> 1) | ((i & 1) << (log_n - 1))
        _BIT_REVERSAL_CACHE[n] = table
    return table


def _twiddles(half, inverse):
    """Return (cached) powers w^0..w^(half-1) of a primitive (2*half)-th root of unity.

    With ``inverse`` true the inverse root is used.
    """
    key = (half, inverse)
    table = _TWIDDLE_CACHE.get(key)
    if table is None:
        root = pow(PRIMITIVE_ROOT, (MOD - 1) // (2 * half), MOD)
        if inverse:
            root = pow(root, MOD - 2, MOD)
        table = [1] * half
        for j in range(1, half):
            table[j] = table[j - 1] * root % MOD
        _TWIDDLE_CACHE[key] = table
    return table


def ntt(a, inverse=False):
    """Number-theoretic transform of ``a`` in place; returns ``a``.

    ``len(a)`` must be a power of two (at most 2**23).  With ``inverse`` true
    the inverse transform is applied, including the 1/n scaling.
    """
    n = len(a)
    if n <= 1:
        return a
    rev = _bit_reversal(n)
    for i in range(n):
        j = rev[i]
        if i < j:
            a[i], a[j] = a[j], a[i]
    half = 1
    while half < n:
        tw = _twiddles(half, inverse)
        step = half << 1
        for start in range(0, n, step):
            for j in range(half):
                lo = start + j
                hi = lo + half
                u = a[lo]
                v = a[hi] * tw[j] % MOD
                a[lo] = (u + v) % MOD
                a[hi] = (u - v) % MOD  # Python's % is already non-negative
        half = step
    if inverse:
        n_inv = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * n_inv % MOD
    return a


def poly_multiply(p, q, limit):
    """Return the product ``p * q`` truncated to degree ``limit``.

    The result has ``min(len(p) + len(q) - 1, limit + 1)`` coefficients; an
    empty operand yields ``limit + 1`` zeros.  The arguments are not modified.
    Passing the same list twice squares it with a single forward transform.
    """
    if not p or not q:
        return [0] * (limit + 1)
    square = p is q
    out_len = min(len(p) + len(q) - 1, limit + 1)
    # Coefficients at index >= out_len cannot influence the kept part.
    p = p[:out_len]
    q = p if square else q[:out_len]
    if len(p) > len(q):
        p, q = q, p
    if len(p) <= NAIVE_MULTIPLY_LIMIT:
        res = [0] * out_len
        for i, pi in enumerate(p):
            if pi:
                for j in range(min(len(q), out_len - i)):
                    res[i + j] = (res[i + j] + pi * q[j]) % MOD
        return res
    # Smallest power of two >= len(p) + len(q) - 1, so the cyclic convolution
    # equals the linear one.
    size = 1 << (len(p) + len(q) - 2).bit_length()
    fp = ntt(p + [0] * (size - len(p)))
    fq = fp if square else ntt(q + [0] * (size - len(q)))
    prod = ntt([x * y % MOD for x, y in zip(fp, fq)], inverse=True)
    return prod[:out_len]


def mul_power(acc, base, exponent, limit):
    """Return ``acc * base**exponent`` truncated to degree ``limit`` (binary powering)."""
    while exponent > 0:
        if exponent & 1:
            acc = poly_multiply(acc, base, limit)
        exponent >>= 1
        if exponent:
            base = poly_multiply(base, base, limit)
    return acc


def exp_power_by_recurrence(level, count, limit, inv_fact, inv_nums):
    """Return e_level(x)**count mod x^(limit+1) in O(limit * level) operations.

    Here e_v(x) = sum_{j=0}^{v} x^j / j!.  Q = P**count with P = e_level
    satisfies P*Q' = count*P'*Q with P' = e_{level-1}; comparing coefficients
    (P_0 = 1) gives, for k >= 1 and m = min(k, level),
        k*Q_k = (count+1) * sum_{j=1}^{m} Q_{k-j}/(j-1)!  -  k * sum_{j=1}^{m} Q_{k-j}/j!

    ``inv_fact`` and ``inv_nums`` (i -> i^-1 mod MOD) must cover indices up to
    ``limit``.
    """
    q = [0] * (limit + 1)
    q[0] = 1
    c1 = (count + 1) % MOD
    for deg in range(1, limit + 1):
        s_shift = 0  # sum of Q_{deg-j} / (j-1)!
        s_plain = 0  # sum of Q_{deg-j} / j!
        for j in range(1, min(deg, level) + 1):
            prev = q[deg - j]
            s_shift = (s_shift + prev * inv_fact[j - 1]) % MOD
            s_plain = (s_plain + prev * inv_fact[j]) % MOD
        q[deg] = (c1 * s_shift - deg * s_plain) % MOD * inv_nums[deg] % MOD
    return q


def _powering_cost(exponent, limit):
    """Rough operation-count estimate for ``mul_power`` at this size (heuristic only)."""
    size = 1
    while size < 2 * limit + 1:
        size <<= 1
    multiplies = exponent.bit_length() + bin(exponent).count("1")
    return multiplies * 3 * size * size.bit_length() // 2


def truncated_exp_power(level, count, limit, inv_fact, inv_nums):
    """Return e_level(x)**count mod x^(limit+1), e_v(x) = sum_{j<=v} x^j/j!.

    ``inv_fact`` and ``inv_nums`` must cover indices up to ``limit``.
    """
    if count == 0 or level == 0:
        return [1]
    if level >= limit:
        # e_level == exp(x) mod x^(limit+1), hence the power is exp(count*x).
        out = []
        term = 1
        for j in range(limit + 1):
            out.append(term * inv_fact[j] % MOD)
            term = term * count % MOD
        return out
    if count == 1:
        return inv_fact[:level + 1]
    # Both methods are exact; pick whichever is estimated to be cheaper.
    if level * (limit + 1) <= _powering_cost(count, limit):
        return exp_power_by_recurrence(level, count, limit, inv_fact, inv_nums)
    return mul_power([1], inv_fact[:level + 1], count, limit)


def euler_phi(m):
    """Return Euler's totient phi(m) for m >= 1, by trial division."""
    result = m
    rest = m
    p = 2
    while p * p <= rest:
        if rest % p == 0:
            while rest % p == 0:
                rest //= p
            result -= result // p
        p += 1
    if rest > 1:  # one prime factor larger than sqrt(m) remains
        result -= result // rest
    return result


def divisors(n):
    """Return the positive divisors of n (integer arithmetic, unordered)."""
    result = []
    i = 1
    while i * i <= n:
        if n % i == 0:
            result.append(i)
            if i * i != n:
                result.append(n // i)
        i += 1
    return result


def count_arrangements(k, a):
    """Return the number of rotation-distinct rings of ``k`` dolls, mod MOD.

    Kind i may be used at most ``a[i]`` times.  By Burnside's lemma the
    answer is (1/k) * sum_{d | k} phi(k/d) * F(d): each of the phi(L)
    rotations whose cycles all have length L = k/d fixes exactly the
    colourings constant on every one of its d cycles, and
    F(d) = d! * [x^d] prod_i e_{min(a[i] // L, d)}(x).
    """
    fact, inv_fact = factorials(k)
    inv_nums = [0] * (k + 1)  # inv_nums[i] = i^-1 mod MOD
    for i in range(1, k + 1):
        inv_nums[i] = fact[i - 1] * inv_fact[i] % MOD
    total = 0
    for d in divisors(k):
        cycle_len = k // d
        # Kinds sharing a truncation level contribute e_level**count jointly;
        # level 0 contributes the constant 1 and is skipped.
        groups = {}
        for cap in a:
            level = min(cap // cycle_len, d)
            if level:
                groups[level] = groups.get(level, 0) + 1
        poly = [1]
        for level, count in groups.items():
            factor = truncated_exp_power(level, count, d, inv_fact, inv_nums)
            poly = poly_multiply(poly, factor, d)
        coeff = poly[d] if len(poly) > d else 0
        total = (total + euler_phi(cycle_len) * coeff % MOD * fact[d]) % MOD
    return total * pow(k, MOD - 2, MOD) % MOD


def solve():
    """Read ``N K`` and the N caps ``A_i`` from stdin and print the count mod MOD."""
    data = sys.stdin.buffer.read().split()
    n, k = int(data[0]), int(data[1])
    a = [int(tok) for tok in data[2:2 + n]]
    print(count_arrangements(k, a))
# Run only when executed as a script, so importing this module does not
# block on reading stdin.
if __name__ == "__main__":
    solve()
qwewe