結果
問題 | No.2959 Dolls' Tea Party |
ユーザー |
![]() |
提出日時 | 2025-03-26 16:00:05 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 5,811 bytes |
コンパイル時間 | 252 ms |
コンパイル使用メモリ | 82,576 KB |
実行使用メモリ | 146,612 KB |
最終ジャッジ日時 | 2025-03-26 16:01:14 |
合計ジャッジ時間 | 6,806 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 5 TLE * 1 -- * 27 |
ソースコード
import sysMOD = 998244353def main():input = sys.stdin.read().split()ptr = 0N, K = int(input[ptr]), int(input[ptr+1])ptr +=2A = list(map(int, input[ptr:ptr+N]))# Precompute factorial and inverse factorial up to s_max + K_maxs_max = Kfact = [1]*(s_max + 1)for i in range(1, s_max+1):fact[i] = fact[i-1] * i % MODinv_fact = [1]*(s_max +1)inv_fact[s_max] = pow(fact[s_max], MOD-2, MOD)for i in range(s_max-1, -1, -1):inv_fact[i] = inv_fact[i+1] * (i+1) % MOD# Precompute divisors of Kdef get_divisors(n):divisors = set()for i in range(1, int(n**0.5)+1):if n %i ==0:divisors.add(i)divisors.add(n//i)return sorted(divisors)divisors = get_divisors(K)divisors.sort()# Precompute phi for each divisorphi = {}for d in divisors:m = dres = mfor i in range(2, int(m**0.5)+1):if i*i >m:breakif m %i ==0:res = res //i * (i-1)while m %i ==0:m = m//iif m>1:res = res //m * (m-1)phi[d] = res# Precompute frequency array for A_i < KK_floor = Kfreq = [0]*(K_floor)C = 0for a in A:if a >= K:C +=1else:freq[a] +=1# Precompute for each divisor dtotal =0for d in divisors:s = K // dif s ==0:continue# Compute m_i for all i where A_i <Ksum_floor_high =0sum_floor_low =0for x in range(K_floor):sum_floor_low += freq[x] * (x //d)# sum_floor_high is sum over A_i >=K of floor(A_i/d)sum_floor_high =0for a in A:if a >=K:sum_floor_high += (a //d)total_B = sum_floor_low + sum_floor_highif total_B < s:continue# Now compute the number of sequences of length s, with each type count <= m_i# m_i is floor(A_i/d)# Split into two parts: m_i >=s and m_i <s# Part 1: m_i >=s (i.e., A_i >=d*s = K)# C is the number of such types# Part 2: m_i <s (A_i < K)# Compute their generating functions# Part 1: (sum_{k=0}^s x^k /k! )^C# Part 2: product_{i: m_i <s} (sum_{k=0}^m_i x^k/k! )dp = [0]*(s+1)dp[0] =1# Handle part 1: (sum_{k=0}^s x^k/k! )^C# We need to raise the polynomial to the Cth power# First, compute the polynomial (sum x^k/k! for k=0 to s)poly_part1 = [0]*(s+1)for k in range(s+1):poly_part1[k] = inv_fact[k]# Compute poly_part1^C using exponentiation by squaringdef poly_pow(p, exponent, max_degree):result = [0]*(max_degree+1)result[0] =1current = p.copy()while exponent >0:if exponent %2 ==1:# Multiply result by currentnew_result = [0]*(max_degree+1)for i in range(max_degree+1):if result[i] ==0:continuefor j in range(max_degree+1 -i):new_result[i+j] = (new_result[i+j] + result[i] * current[j]) % MODresult = new_result# Square currentnew_current = [0]*(max_degree*2 +1)for i in range(max_degree+1):if current[i] ==0:continuefor j in range(max_degree+1 -i):new_current[i+j] = (new_current[i+j] + current[i] * current[j]) % MODcurrent = new_current[:max_degree+1]exponent = exponent //2return resultpart1 = poly_pow(poly_part1, C, s)# Now handle part2: types with A_i < K, m_i = floor(A_i/d)part2 = [0]*(s+1)part2[0] =1# Precompute m_i for all i where A_i <K# Group by m_igroups = {}for x in range(K_floor):cnt = freq[x]if cnt ==0:continuem_i = x //dif m_i >=s:continue # already handled in part1if m_i not in groups:groups[m_i] =0groups[m_i] += cnt# Process each group in part2for t in groups:cnt = groups[t]if cnt ==0:continue# The polynomial for this group is (sum_{k=0}^t x^k/k! )^cntpoly = [inv_fact[k] if k <=t else 0 for k in range(s+1)]poly = poly_pow(poly, cnt, s)# Multiply into part2new_part2 = [0]*(s+1)for i in range(s+1):if part2[i] ==0:continuefor j in range(s+1 -i):new_part2[i+j] = (new_part2[i+j] + part2[i] * poly[j]) % MODpart2 = new_part2# Combine part1 and part2combined = [0]*(s+1)for i in range(s+1):if part1[i] ==0:continuefor j in range(s+1 -i):combined[i+j] = (combined[i+j] + part1[i] * part2[j]) % MOD# Get the coefficient of x^scoeff = combined[s]f_d = coeff * fact[s] % MOD# Add to totaltotal = (total + phi[d] * f_d) % MOD# Compute answeranswer = total * pow(K, MOD-2, MOD) % MODprint(answer)if __name__ == '__main__':main()